Mise à jour de Monitor.py et autres scripts

This commit is contained in:
Debian
2025-07-23 10:46:27 +02:00
parent 7081418ce0
commit 7de3e0fb50
8604 changed files with 2789953 additions and 295 deletions

View File

@@ -0,0 +1,50 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from streamlit.runtime.runtime import Runtime, RuntimeConfig, RuntimeState
from streamlit.runtime.session_manager import (
SessionClient,
SessionClientDisconnectedError,
)
def get_instance() -> Runtime:
"""Return the singleton Runtime instance. Raise an Error if the
Runtime hasn't been created yet.
"""
return Runtime.instance()
def exists() -> bool:
"""True if the singleton Runtime instance has been created.
When a Streamlit app is running in "raw mode" - that is, when the
app is run via `python app.py` instead of `streamlit run app.py` -
the Runtime will not exist, and various Streamlit functions need
to adapt.
"""
return Runtime.exists()
__all__ = [
"Runtime",
"RuntimeConfig",
"RuntimeState",
"SessionClient",
"SessionClientDisconnectedError",
"get_instance",
"exists",
]

View File

@@ -0,0 +1,985 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import json
import sys
import uuid
from enum import Enum
from typing import TYPE_CHECKING, Callable, Final
from google.protobuf.json_format import ParseDict
import streamlit.elements.exception as exception_utils
from streamlit import config, runtime
from streamlit.logger import get_logger
from streamlit.proto.ClientState_pb2 import ClientState
from streamlit.proto.Common_pb2 import FileURLs, FileURLsRequest
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.GitInfo_pb2 import GitInfo
from streamlit.proto.NewSession_pb2 import (
Config,
CustomThemeConfig,
FontFace,
NewSession,
UserInfo,
)
from streamlit.runtime import caching
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
from streamlit.runtime.fragment import FragmentStorage, MemoryFragmentStorage
from streamlit.runtime.metrics_util import Installation
from streamlit.runtime.pages_manager import PagesManager
from streamlit.runtime.scriptrunner import RerunData, ScriptRunner, ScriptRunnerEvent
from streamlit.runtime.secrets import secrets_singleton
from streamlit.string_util import to_snake_case
from streamlit.version import STREAMLIT_VERSION_STRING
from streamlit.watcher import LocalSourcesWatcher
if TYPE_CHECKING:
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
from streamlit.runtime.state import SessionState
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from streamlit.source_util import PageHash, PageInfo
_LOGGER: Final = get_logger(__name__)
class AppSessionState(Enum):
APP_NOT_RUNNING = "APP_NOT_RUNNING"
APP_IS_RUNNING = "APP_IS_RUNNING"
SHUTDOWN_REQUESTED = "SHUTDOWN_REQUESTED"
def _generate_scriptrun_id() -> str:
"""Randomly generate a unique ID for a script execution."""
return str(uuid.uuid4())
class AppSession:
"""
Contains session data for a single "user" of an active app
(that is, a connected browser tab).
Each AppSession has its own ScriptData, root DeltaGenerator, ScriptRunner,
and widget state.
An AppSession is attached to each thread involved in running its script.
"""
def __init__(
self,
script_data: ScriptData,
uploaded_file_manager: UploadedFileManager,
script_cache: ScriptCache,
message_enqueued_callback: Callable[[], None] | None,
user_info: dict[str, str | bool | None],
session_id_override: str | None = None,
) -> None:
"""Initialize the AppSession.
Parameters
----------
script_data
Object storing parameters related to running a script
uploaded_file_manager
Used to manage files uploaded by users via the Streamlit web client.
script_cache
The app's ScriptCache instance. Stores cached user scripts. ScriptRunner
uses the ScriptCache to avoid having to reload user scripts from disk
on each rerun.
message_enqueued_callback
After enqueuing a message, this callable notification will be invoked.
user_info
A dict that contains information about the current user. For now,
it only contains the user's email address.
{
"email": "example@example.com"
}
Information about the current user is optionally provided when a
websocket connection is initialized via the "X-Streamlit-User" header.
session_id_override
The ID to assign to this session. Setting this can be useful when the
service that a Streamlit Runtime is running in wants to tie the lifecycle of
a Streamlit session to some other session-like object that it manages.
"""
# Each AppSession has a unique string ID.
self.id = session_id_override or str(uuid.uuid4())
self._event_loop = asyncio.get_running_loop()
self._script_data = script_data
self._uploaded_file_mgr = uploaded_file_manager
self._script_cache = script_cache
self._pages_manager = PagesManager(
script_data.main_script_path, self._script_cache
)
# The browser queue contains messages that haven't yet been
# delivered to the browser. Periodically, the server flushes
# this queue and delivers its contents to the browser.
self._browser_queue = ForwardMsgQueue()
self._message_enqueued_callback = message_enqueued_callback
self._state = AppSessionState.APP_NOT_RUNNING
# Need to remember the client state here because when a script reruns
# due to the source code changing we need to pass in the previous client state.
self._client_state = ClientState()
self._local_sources_watcher: LocalSourcesWatcher | None = None
self._stop_config_listener: Callable[[], bool] | None = None
self._stop_pages_listener: Callable[[], None] | None = None
if config.get_option("server.fileWatcherType") != "none":
self.register_file_watchers()
self._run_on_save = config.get_option("server.runOnSave")
self._scriptrunner: ScriptRunner | None = None
# This needs to be lazily imported to avoid a dependency cycle.
from streamlit.runtime.state import SessionState
self._session_state = SessionState()
self._user_info = user_info
self._debug_last_backmsg_id: str | None = None
self._fragment_storage: FragmentStorage = MemoryFragmentStorage()
_LOGGER.debug("AppSession initialized (id=%s)", self.id)
def __del__(self) -> None:
"""Ensure that we call shutdown() when an AppSession is garbage collected."""
self.shutdown()
def register_file_watchers(self) -> None:
"""Register handlers to be called when various files are changed.
Files that we watch include:
- source files that already exist (for edits)
- `.py` files in the the main script's `pages/` directory (for file additions
and deletions)
- project and user-level config.toml files
- the project-level secrets.toml files
This method is called automatically on AppSession construction, but it may be
called again in the case when a session is disconnected and is being reconnect
to.
"""
if self._local_sources_watcher is None:
self._local_sources_watcher = LocalSourcesWatcher(self._pages_manager)
self._local_sources_watcher.register_file_change_callback(
self._on_source_file_changed
)
self._stop_config_listener = config.on_config_parsed(
self._on_source_file_changed, force_connect=True
)
secrets_singleton.file_change_listener.connect(self._on_secrets_file_changed)
def disconnect_file_watchers(self) -> None:
"""Disconnect the file watcher handlers registered by register_file_watchers."""
if self._local_sources_watcher is not None:
self._local_sources_watcher.close()
if self._stop_config_listener is not None:
self._stop_config_listener()
if self._stop_pages_listener is not None:
self._stop_pages_listener()
secrets_singleton.file_change_listener.disconnect(self._on_secrets_file_changed)
self._local_sources_watcher = None
self._stop_config_listener = None
self._stop_pages_listener = None
def flush_browser_queue(self) -> list[ForwardMsg]:
"""Clear the forward message queue and return the messages it contained.
The Server calls this periodically to deliver new messages
to the browser connected to this app.
Returns
-------
list[ForwardMsg]
The messages that were removed from the queue and should
be delivered to the browser.
"""
return self._browser_queue.flush()
def shutdown(self) -> None:
"""Shut down the AppSession.
It's an error to use a AppSession after it's been shut down.
"""
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
_LOGGER.debug("Shutting down (id=%s)", self.id)
# Clear any unused session files in upload file manager and media
# file manager
self._uploaded_file_mgr.remove_session_files(self.id)
if runtime.exists():
rt = runtime.get_instance()
rt.media_file_mgr.clear_session_refs(self.id)
rt.media_file_mgr.remove_orphaned_files()
# Shut down the ScriptRunner, if one is active.
# self._state must not be set to SHUTDOWN_REQUESTED until
# *after* this is called.
self.request_script_stop()
self._state = AppSessionState.SHUTDOWN_REQUESTED
# Disconnect all file watchers if we haven't already, although we will have
# generally already done so by the time we get here.
self.disconnect_file_watchers()
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
"""Enqueue a new ForwardMsg to our browser queue.
This can be called on both the main thread and a ScriptRunner
run thread.
Parameters
----------
msg : ForwardMsg
The message to enqueue
"""
if self._debug_last_backmsg_id:
msg.debug_last_backmsg_id = self._debug_last_backmsg_id
self._browser_queue.enqueue(msg)
if self._message_enqueued_callback:
self._message_enqueued_callback()
def handle_backmsg(self, msg: BackMsg) -> None:
"""Process a BackMsg."""
try:
msg_type = msg.WhichOneof("type")
if msg_type == "rerun_script":
if msg.debug_last_backmsg_id:
self._debug_last_backmsg_id = msg.debug_last_backmsg_id
self._handle_rerun_script_request(msg.rerun_script)
elif msg_type == "load_git_info":
self._handle_git_information_request()
elif msg_type == "clear_cache":
self._handle_clear_cache_request()
elif msg_type == "app_heartbeat":
self._handle_app_heartbeat_request()
elif msg_type == "set_run_on_save":
self._handle_set_run_on_save_request(msg.set_run_on_save)
elif msg_type == "stop_script":
self._handle_stop_script_request()
elif msg_type == "file_urls_request":
self._handle_file_urls_request(msg.file_urls_request)
else:
_LOGGER.warning('No handler for "%s"', msg_type)
except Exception as ex:
_LOGGER.exception("Error processing back message")
self.handle_backmsg_exception(ex)
def handle_backmsg_exception(self, e: BaseException) -> None:
"""Handle an Exception raised while processing a BackMsg from the browser."""
# This does a few things:
# 1) Clears the current app in the browser.
# 2) Marks the current app as "stopped" in the browser.
# 3) HACK: Resets any script params that may have been broken (e.g. the
# command-line when rerunning with wrong argv[0])
self._on_scriptrunner_event(
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
)
self._on_scriptrunner_event(
self._scriptrunner,
ScriptRunnerEvent.SCRIPT_STARTED,
page_script_hash="",
)
self._on_scriptrunner_event(
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
)
# Send an Exception message to the frontend.
# Because _on_scriptrunner_event does its work in an eventloop callback,
# this exception ForwardMsg *must* also be enqueued in a callback,
# so that it will be enqueued *after* the various ForwardMsgs that
# _on_scriptrunner_event sends.
self._event_loop.call_soon_threadsafe(
lambda: self._enqueue_forward_msg(self._create_exception_message(e))
)
def request_rerun(self, client_state: ClientState | None) -> None:
"""Signal that we're interested in running the script.
If the script is not already running, it will be started immediately.
Otherwise, a rerun will be requested.
Parameters
----------
client_state : streamlit.proto.ClientState_pb2.ClientState | None
The ClientState protobuf to run the script with, or None
to use previous client state.
"""
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
_LOGGER.warning("Discarding rerun request after shutdown")
return
if client_state:
fragment_id = client_state.fragment_id
# Early check whether this fragment still exists in the fragment storage or
# might have been removed by a full app run. This is not merely a
# performance optimization, but also fixes following potential situation:
# A fragment run might create a new ScriptRunner when the current
# ScriptRunner is in state STOPPED (in this case, the 'success' variable
# below is false and the new ScriptRunner is created). This will lead to all
# events that were not sent / received from the previous script runner to be
# ignored in _handle_scriptrunner_event_on_event_loop, because the
# _script_runner changed. When the full app rerun ScriptRunner is done
# (STOPPED) but its events are not processed before the new ScriptRunner is
# created, its finished message is not sent to the frontend and no
# full-app-run cleanup is happening. This scenario can be triggered by the
# example app described in
# https://github.com/streamlit/streamlit/issues/9921, where the dialog
# sometimes stays open.
if fragment_id and not self._fragment_storage.contains(fragment_id):
_LOGGER.info(
f"The fragment with id {fragment_id} does not exist anymore - "
"it might have been removed during a preceding full-app rerun."
)
return
if client_state.HasField("context_info"):
self._client_state.context_info.CopyFrom(client_state.context_info)
rerun_data = RerunData(
client_state.query_string,
client_state.widget_states,
client_state.page_script_hash,
client_state.page_name,
fragment_id=fragment_id if fragment_id else None,
is_auto_rerun=client_state.is_auto_rerun,
context_info=client_state.context_info,
)
else:
rerun_data = RerunData()
if self._scriptrunner is not None:
if (
bool(config.get_option("runner.fastReruns"))
and not rerun_data.fragment_id
):
# If fastReruns is enabled and this is *not* a rerun of a fragment,
# we don't send rerun requests to our existing ScriptRunner. Instead, we
# tell it to shut down. We'll then spin up a new ScriptRunner, below, to
# handle the rerun immediately.
self._scriptrunner.request_stop()
self._scriptrunner = None
else:
# Either fastReruns is not enabled or this RERUN request is a request to
# run a fragment. We send our current ScriptRunner a rerun request, and
# if it's accepted, we're done.
success = self._scriptrunner.request_rerun(rerun_data)
if success:
return
# If we are here, then either we have no ScriptRunner, or our
# current ScriptRunner is shutting down and cannot handle a rerun
# request - so we'll create and start a new ScriptRunner.
self._create_scriptrunner(rerun_data)
def request_script_stop(self) -> None:
"""Request that the scriptrunner stop execution.
Does nothing if no scriptrunner exists.
"""
if self._scriptrunner is not None:
self._scriptrunner.request_stop()
def clear_user_info(self) -> None:
"""Clear the user info for this session."""
self._user_info.clear()
def _create_scriptrunner(self, initial_rerun_data: RerunData) -> None:
"""Create and run a new ScriptRunner with the given RerunData."""
self._scriptrunner = ScriptRunner(
session_id=self.id,
main_script_path=self._script_data.main_script_path,
session_state=self._session_state,
uploaded_file_mgr=self._uploaded_file_mgr,
script_cache=self._script_cache,
initial_rerun_data=initial_rerun_data,
user_info=self._user_info,
fragment_storage=self._fragment_storage,
pages_manager=self._pages_manager,
)
self._scriptrunner.on_event.connect(self._on_scriptrunner_event)
self._scriptrunner.start()
@property
def session_state(self) -> SessionState:
return self._session_state
def _should_rerun_on_file_change(self, filepath: str) -> bool:
pages = self._pages_manager.get_pages()
changed_page_script_hash = next(
filter(lambda k: pages[k]["script_path"] == filepath, pages),
None,
)
if changed_page_script_hash is not None:
current_page_script_hash = self._client_state.page_script_hash
return changed_page_script_hash == current_page_script_hash
return True
def _on_source_file_changed(self, filepath: str | None = None) -> None:
"""One of our source files changed. Clear the cache and schedule a rerun if
appropriate.
"""
self._script_cache.clear()
if filepath is not None and not self._should_rerun_on_file_change(filepath):
return
if self._run_on_save:
self.request_rerun(self._client_state)
else:
self._enqueue_forward_msg(self._create_file_change_message())
def _on_secrets_file_changed(self, _) -> None:
"""Called when `secrets.file_change_listener` emits a Signal."""
# NOTE: At the time of writing, this function only calls
# `_on_source_file_changed`. The reason behind creating this function instead of
# just passing `_on_source_file_changed` to `connect` / `disconnect` directly is
# that every function that is passed to `connect` / `disconnect` must have at
# least one argument for `sender` (in this case we don't really care about it,
# thus `_`), and introducing an unnecessary argument to
# `_on_source_file_changed` just for this purpose sounded finicky.
self._on_source_file_changed()
def _clear_queue(self, fragment_ids_this_run: list[str] | None = None) -> None:
self._browser_queue.clear(
retain_lifecycle_msgs=True, fragment_ids_this_run=fragment_ids_this_run
)
def _on_scriptrunner_event(
self,
sender: ScriptRunner | None,
event: ScriptRunnerEvent,
forward_msg: ForwardMsg | None = None,
exception: BaseException | None = None,
client_state: ClientState | None = None,
page_script_hash: str | None = None,
fragment_ids_this_run: list[str] | None = None,
pages: dict[PageHash, PageInfo] | None = None,
) -> None:
"""Called when our ScriptRunner emits an event.
This is generally called from the sender ScriptRunner's script thread.
We forward the event on to _handle_scriptrunner_event_on_event_loop,
which will be called on the main thread.
"""
self._event_loop.call_soon_threadsafe(
lambda: self._handle_scriptrunner_event_on_event_loop(
sender,
event,
forward_msg,
exception,
client_state,
page_script_hash,
fragment_ids_this_run,
pages,
)
)
def _handle_scriptrunner_event_on_event_loop(
self,
sender: ScriptRunner | None,
event: ScriptRunnerEvent,
forward_msg: ForwardMsg | None = None,
exception: BaseException | None = None,
client_state: ClientState | None = None,
page_script_hash: str | None = None,
fragment_ids_this_run: list[str] | None = None,
pages: dict[PageHash, PageInfo] | None = None,
) -> None:
"""Handle a ScriptRunner event.
This function must only be called on our eventloop thread.
Parameters
----------
sender : ScriptRunner | None
The ScriptRunner that emitted the event. (This may be set to
None when called from `handle_backmsg_exception`, if no
ScriptRunner was active when the backmsg exception was raised.)
event : ScriptRunnerEvent
The event type.
forward_msg : ForwardMsg | None
The ForwardMsg to send to the frontend. Set only for the
ENQUEUE_FORWARD_MSG event.
exception : BaseException | None
An exception thrown during compilation. Set only for the
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
client_state : streamlit.proto.ClientState_pb2.ClientState | None
The ScriptRunner's final ClientState. Set only for the
SHUTDOWN event.
page_script_hash : str | None
A hash of the script path corresponding to the page currently being
run. Set only for the SCRIPT_STARTED event.
fragment_ids_this_run : list[str] | None
The fragment IDs of the fragments being executed in this script run. Only
set for the SCRIPT_STARTED event. If this value is falsy, this script run
must be for the full script.
clear_forward_msg_queue : bool
If set (the default), clears the queue of forward messages to be sent to the
browser. Set only for the SCRIPT_STARTED event.
"""
assert self._event_loop == asyncio.get_running_loop(), (
"This function must only be called on the eventloop thread the AppSession was created on."
)
if sender is not self._scriptrunner:
# This event was sent by a non-current ScriptRunner; ignore it.
# This can happen after sppinng up a new ScriptRunner (to handle a
# rerun request, for example) while another ScriptRunner is still
# shutting down. The shutting-down ScriptRunner may still
# emit events.
_LOGGER.debug("Ignoring event from non-current ScriptRunner: %s", event)
return
prev_state = self._state
if event == ScriptRunnerEvent.SCRIPT_STARTED:
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
self._state = AppSessionState.APP_IS_RUNNING
assert page_script_hash is not None, (
"page_script_hash must be set for the SCRIPT_STARTED event"
)
# Update the client state with the new page_script_hash if
# necessary. This handles an edge case where a script is never
# finishes (eg. by calling st.rerun()), but the page has changed
# via st.navigation()
if page_script_hash != self._client_state.page_script_hash:
self._client_state.page_script_hash = page_script_hash
self._clear_queue(fragment_ids_this_run)
self._enqueue_forward_msg(
self._create_new_session_message(
page_script_hash, fragment_ids_this_run, pages
)
)
elif (
event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
or event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR
or event == ScriptRunnerEvent.FRAGMENT_STOPPED_WITH_SUCCESS
):
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
self._state = AppSessionState.APP_NOT_RUNNING
if event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS:
status = ForwardMsg.FINISHED_SUCCESSFULLY
elif event == ScriptRunnerEvent.FRAGMENT_STOPPED_WITH_SUCCESS:
status = ForwardMsg.FINISHED_FRAGMENT_RUN_SUCCESSFULLY
else:
status = ForwardMsg.FINISHED_WITH_COMPILE_ERROR
self._enqueue_forward_msg(self._create_script_finished_message(status))
self._debug_last_backmsg_id = None
if (
event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
or event == ScriptRunnerEvent.FRAGMENT_STOPPED_WITH_SUCCESS
):
# The script completed successfully: update our
# LocalSourcesWatcher to account for any source code changes
# that change which modules should be watched.
if self._local_sources_watcher:
self._local_sources_watcher.update_watched_modules()
self._local_sources_watcher.update_watched_pages()
else:
# The script didn't complete successfully: send the exception
# to the frontend.
assert exception is not None, (
"exception must be set for the SCRIPT_STOPPED_WITH_COMPILE_ERROR event"
)
msg = ForwardMsg()
exception_utils.marshall(
msg.session_event.script_compilation_exception, exception
)
self._enqueue_forward_msg(msg)
elif event == ScriptRunnerEvent.SCRIPT_STOPPED_FOR_RERUN:
self._state = AppSessionState.APP_NOT_RUNNING
self._enqueue_forward_msg(
self._create_script_finished_message(
ForwardMsg.FINISHED_EARLY_FOR_RERUN
)
)
if self._local_sources_watcher:
self._local_sources_watcher.update_watched_modules()
elif event == ScriptRunnerEvent.SHUTDOWN:
assert client_state is not None, (
"client_state must be set for the SHUTDOWN event"
)
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
# Only clear media files if the script is done running AND the
# session is actually shutting down.
runtime.get_instance().media_file_mgr.clear_session_refs(self.id)
self._client_state = client_state
self._scriptrunner = None
elif event == ScriptRunnerEvent.ENQUEUE_FORWARD_MSG:
assert forward_msg is not None, (
"null forward_msg in ENQUEUE_FORWARD_MSG event"
)
self._enqueue_forward_msg(forward_msg)
# Send a message if our run state changed
app_was_running = prev_state == AppSessionState.APP_IS_RUNNING
app_is_running = self._state == AppSessionState.APP_IS_RUNNING
if app_is_running != app_was_running:
self._enqueue_forward_msg(self._create_session_status_changed_message())
def _create_session_status_changed_message(self) -> ForwardMsg:
"""Create and return a session_status_changed ForwardMsg."""
msg = ForwardMsg()
msg.session_status_changed.run_on_save = self._run_on_save
msg.session_status_changed.script_is_running = (
self._state == AppSessionState.APP_IS_RUNNING
)
return msg
def _create_file_change_message(self) -> ForwardMsg:
"""Create and return a 'script_changed_on_disk' ForwardMsg."""
msg = ForwardMsg()
msg.session_event.script_changed_on_disk = True
return msg
def _create_new_session_message(
self,
page_script_hash: str,
fragment_ids_this_run: list[str] | None = None,
pages: dict[PageHash, PageInfo] | None = None,
) -> ForwardMsg:
"""Create and return a new_session ForwardMsg."""
msg = ForwardMsg()
msg.new_session.script_run_id = _generate_scriptrun_id()
msg.new_session.name = self._script_data.name
msg.new_session.main_script_path = self._pages_manager.main_script_path
msg.new_session.main_script_hash = self._pages_manager.main_script_hash
msg.new_session.page_script_hash = page_script_hash
if fragment_ids_this_run:
msg.new_session.fragment_ids_this_run.extend(fragment_ids_this_run)
self._populate_app_pages(
msg.new_session, pages or self._pages_manager.get_pages()
)
_populate_config_msg(msg.new_session.config)
_populate_theme_msg(msg.new_session.custom_theme)
_populate_theme_msg(
msg.new_session.custom_theme.sidebar,
f"theme.{config.CustomThemeCategories.SIDEBAR.value}",
)
# Immutable session data. We send this every time a new session is
# started, to avoid having to track whether the client has already
# received it. It does not change from run to run; it's up to the
# to perform one-time initialization only once.
imsg = msg.new_session.initialize
_populate_user_info_msg(imsg.user_info)
imsg.environment_info.streamlit_version = STREAMLIT_VERSION_STRING
imsg.environment_info.python_version = ".".join(map(str, sys.version_info))
imsg.session_status.run_on_save = self._run_on_save
imsg.session_status.script_is_running = (
self._state == AppSessionState.APP_IS_RUNNING
)
imsg.is_hello = self._script_data.is_hello
imsg.session_id = self.id
return msg
def _create_script_finished_message(
self, status: ForwardMsg.ScriptFinishedStatus.ValueType
) -> ForwardMsg:
"""Create and return a script_finished ForwardMsg."""
msg = ForwardMsg()
msg.script_finished = status
return msg
def _create_exception_message(self, e: BaseException) -> ForwardMsg:
"""Create and return an Exception ForwardMsg."""
msg = ForwardMsg()
exception_utils.marshall(msg.delta.new_element.exception, e)
return msg
def _handle_git_information_request(self) -> None:
msg = ForwardMsg()
try:
from streamlit.git_util import GitRepo
repo = GitRepo(self._script_data.main_script_path)
repo_info = repo.get_repo_info()
if repo_info is None:
return
repository_name, branch, module = repo_info
repository_name = repository_name.removesuffix(".git")
msg.git_info_changed.repository = repository_name
msg.git_info_changed.branch = branch
msg.git_info_changed.module = module
msg.git_info_changed.untracked_files[:] = repo.untracked_files
msg.git_info_changed.uncommitted_files[:] = repo.uncommitted_files
if repo.is_head_detached:
msg.git_info_changed.state = GitInfo.GitStates.HEAD_DETACHED
elif len(repo.ahead_commits) > 0:
msg.git_info_changed.state = GitInfo.GitStates.AHEAD_OF_REMOTE
else:
msg.git_info_changed.state = GitInfo.GitStates.DEFAULT
self._enqueue_forward_msg(msg)
except Exception as ex:
# Users may never even install Git in the first place, so this
# error requires no action. It can be useful for debugging.
_LOGGER.debug("Obtaining Git information produced an error", exc_info=ex)
def _handle_rerun_script_request(
self, client_state: ClientState | None = None
) -> None:
"""Tell the ScriptRunner to re-run its script.
Parameters
----------
client_state : streamlit.proto.ClientState_pb2.ClientState | None
The ClientState protobuf to run the script with, or None
to use previous client state.
"""
self.request_rerun(client_state)
def _handle_stop_script_request(self) -> None:
"""Tell the ScriptRunner to stop running its script."""
self.request_script_stop()
def _handle_clear_cache_request(self) -> None:
"""Clear this app's cache.
Because this cache is global, it will be cleared for all users.
"""
caching.cache_data.clear()
caching.cache_resource.clear()
self._session_state.clear()
def _handle_app_heartbeat_request(self) -> None:
"""Handle an incoming app heartbeat.
The heartbeat indicates the frontend is active and keeps the
websocket from going idle and disconnecting.
The actual handler here is a noop
"""
pass
def _handle_set_run_on_save_request(self, new_value: bool) -> None:
"""Change our run_on_save flag to the given value.
The browser will be notified of the change.
Parameters
----------
new_value : bool
New run_on_save value
"""
self._run_on_save = new_value
self._enqueue_forward_msg(self._create_session_status_changed_message())
def _handle_file_urls_request(self, file_urls_request: FileURLsRequest) -> None:
"""Handle a file_urls_request BackMsg sent by the client."""
msg = ForwardMsg()
msg.file_urls_response.response_id = file_urls_request.request_id
upload_url_infos = self._uploaded_file_mgr.get_upload_urls(
self.id, file_urls_request.file_names
)
for upload_url_info in upload_url_infos:
msg.file_urls_response.file_urls.append(
FileURLs(
file_id=upload_url_info.file_id,
upload_url=upload_url_info.upload_url,
delete_url=upload_url_info.delete_url,
)
)
self._enqueue_forward_msg(msg)
def _populate_app_pages(
self, msg: NewSession, pages: dict[PageHash, PageInfo]
) -> None:
for page_script_hash, page_info in pages.items():
page_proto = msg.app_pages.add()
page_proto.page_script_hash = page_script_hash
page_proto.page_name = page_info["page_name"].replace("_", " ")
page_proto.url_pathname = page_info["page_name"]
page_proto.icon = page_info["icon"]
# Config.ToolbarMode.ValueType does not exist at runtime (only in the pyi stubs), so
# we need to use quotes.
# This field will be available at runtime as of protobuf 3.20.1, but
# we are using an older version.
# For details, see: https://github.com/protocolbuffers/protobuf/issues/8175
def _get_toolbar_mode() -> Config.ToolbarMode.ValueType:
config_key = "client.toolbarMode"
config_value = config.get_option(config_key)
enum_value: Config.ToolbarMode.ValueType | None = getattr(
Config.ToolbarMode, config_value.upper()
)
if enum_value is None:
allowed_values = ", ".join(k.lower() for k in Config.ToolbarMode.keys())
raise ValueError(
f"Config {config_key!r} expects to have one of "
f"the following values: {allowed_values}. "
f"Current value: {config_value}"
)
return enum_value
def _populate_config_msg(msg: Config) -> None:
msg.gather_usage_stats = config.get_option("browser.gatherUsageStats")
msg.max_cached_message_age = config.get_option("global.maxCachedMessageAge")
msg.allow_run_on_save = config.get_option("server.allowRunOnSave")
msg.hide_top_bar = config.get_option("ui.hideTopBar")
if config.get_option("client.showSidebarNavigation") is False:
msg.hide_sidebar_nav = True
msg.toolbar_mode = _get_toolbar_mode()
def _populate_theme_msg(msg: CustomThemeConfig, section: str = "theme") -> None:
theme_opts = config.get_options_for_section(section)
if all(val is None for val in theme_opts.values()):
return
for option_name, option_val in theme_opts.items():
# We need to ignore some config options here that need special handling
# and cannot directly be set on the protobuf.
if option_name not in {"base", "font", "fontFaces"} and option_val is not None:
setattr(msg, to_snake_case(option_name), option_val)
# NOTE: If unset, base and font will default to the protobuf enum zero
# values, which are BaseTheme.LIGHT and FontFamily.SANS_SERIF,
# respectively. This is why we both don't handle the cases explicitly and
# also only log a warning when receiving invalid base/font options.
base_map = {
"light": msg.BaseTheme.LIGHT,
"dark": msg.BaseTheme.DARK,
}
base = theme_opts.get("base", None)
if base is not None:
if base not in base_map:
_LOGGER.warning(
f'"{base}" is an invalid value for theme.base.'
f" Allowed values include {list(base_map.keys())}."
' Setting theme.base to "light".'
)
else:
msg.base = base_map[base]
# Since the font field uses the deprecated enum, we need to put the font
# config into the body_font field instead:
body_font = theme_opts.get("font", None)
if body_font:
msg.body_font = body_font
font_faces = theme_opts.get("fontFaces", None)
# If fontFaces was configured via config.toml, it's already a parsed list of
# dictionaries. However, if it was provided via env variable or via CLI arg,
# it's a json string that still needs to be parsed.
if isinstance(font_faces, str):
try:
font_faces = json.loads(font_faces)
except Exception as e:
_LOGGER.warning(
"Failed to parse the theme.fontFaces config option with json.loads: "
f"{font_faces}.",
exc_info=e,
)
font_faces = None
if font_faces is not None:
for font_face in font_faces:
try:
msg.font_faces.append(ParseDict(font_face, FontFace()))
except Exception as e:
_LOGGER.warning(
f"Failed to parse the theme.fontFaces config option: {font_face}.",
exc_info=e,
)
def _populate_user_info_msg(msg: UserInfo) -> None:
msg.installation_id = Installation.instance().installation_id
msg.installation_id_v3 = Installation.instance().installation_id_v3

View File

@@ -0,0 +1,98 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from streamlit.runtime.caching.cache_data_api import (
CACHE_DATA_MESSAGE_REPLAY_CTX,
CacheDataAPI,
get_data_cache_stats_provider,
)
from streamlit.runtime.caching.cache_errors import CACHE_DOCS_URL
from streamlit.runtime.caching.cache_resource_api import (
CACHE_RESOURCE_MESSAGE_REPLAY_CTX,
CacheResourceAPI,
get_resource_cache_stats_provider,
)
from streamlit.runtime.caching.legacy_cache_api import cache as _cache
if TYPE_CHECKING:
from google.protobuf.message import Message
from streamlit.proto.Block_pb2 import Block
def save_element_message(
delta_type: str,
element_proto: Message,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
"""Save the message for an element to a thread-local callstack, so it can
be used later to replay the element when a cache-decorated function's
execution is skipped.
"""
CACHE_DATA_MESSAGE_REPLAY_CTX.save_element_message(
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_element_message(
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
def save_block_message(
block_proto: Block,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
"""Save the message for a block to a thread-local callstack, so it can
be used later to replay the block when a cache-decorated function's
execution is skipped.
"""
CACHE_DATA_MESSAGE_REPLAY_CTX.save_block_message(
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_block_message(
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
def save_media_data(image_data: bytes | str, mimetype: str, image_id: str) -> None:
CACHE_DATA_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
# Create and export public API singletons.
cache_data = CacheDataAPI(decorator_metric_name="cache_data")
cache_resource = CacheResourceAPI(decorator_metric_name="cache_resource")
# TODO(lukasmasuch): This is the legacy cache API name which is deprecated
# and it should be removed in the future.
cache = _cache
__all__ = [
"cache",
"CACHE_DOCS_URL",
"save_element_message",
"save_block_message",
"save_media_data",
"get_data_cache_stats_provider",
"get_resource_cache_stats_provider",
"cache_data",
"cache_resource",
]

View File

@@ -0,0 +1,665 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.cache_data: pickle-based caching."""
from __future__ import annotations
import pickle
import threading
from typing import (
TYPE_CHECKING,
Any,
Callable,
Final,
Literal,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias
import streamlit as st
from streamlit import runtime
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import CacheError, CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
MsgData,
show_widget_replay_deprecation,
)
from streamlit.runtime.caching.storage import (
CacheStorage,
CacheStorageContext,
CacheStorageError,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
InvalidCacheStorageContext,
)
from streamlit.runtime.caching.storage.dummy_cache_storage import (
MemoryCacheStorageManager,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
from streamlit.time_util import time_to_seconds
if TYPE_CHECKING:
import types
from datetime import timedelta
from streamlit.runtime.caching.hashing import HashFuncsDict
_LOGGER: Final = get_logger(__name__)
CACHE_DATA_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.DATA)
# The cache persistence options we support: "disk" or None
CachePersistType: TypeAlias = Union[Literal["disk"], None]
class CachedDataFuncInfo(CachedFuncInfo):
"""Implements the CachedFuncInfo interface for @st.cache_data."""
def __init__(
self,
func: types.FunctionType,
show_spinner: bool | str,
persist: CachePersistType,
max_entries: int | None,
ttl: float | timedelta | str | None,
hash_funcs: HashFuncsDict | None = None,
):
super().__init__(
func,
show_spinner=show_spinner,
hash_funcs=hash_funcs,
)
self.persist = persist
self.max_entries = max_entries
self.ttl = ttl
self.validate_params()
@property
def cache_type(self) -> CacheType:
return CacheType.DATA
@property
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
return CACHE_DATA_MESSAGE_REPLAY_CTX
@property
def display_name(self) -> str:
"""A human-readable name for the cached function."""
return f"{self.func.__module__}.{self.func.__qualname__}"
def get_function_cache(self, function_key: str) -> Cache:
return _data_caches.get_cache(
key=function_key,
persist=self.persist,
max_entries=self.max_entries,
ttl=self.ttl,
display_name=self.display_name,
)
def validate_params(self) -> None:
"""
Validate the params passed to @st.cache_data are compatible with cache storage.
When called, this method could log warnings if cache params are invalid
for current storage.
"""
_data_caches.validate_cache_params(
function_name=self.func.__name__,
persist=self.persist,
max_entries=self.max_entries,
ttl=self.ttl,
)
class DataCaches(CacheStatsProvider):
"""Manages all DataCache instances."""
def __init__(self):
self._caches_lock = threading.Lock()
self._function_caches: dict[str, DataCache] = {}
def get_cache(
self,
key: str,
persist: CachePersistType,
max_entries: int | None,
ttl: int | float | timedelta | str | None,
display_name: str,
) -> DataCache:
"""Return the mem cache for the given key.
If it doesn't exist, create a new one with the given params.
"""
ttl_seconds = time_to_seconds(ttl, coerce_none_to_inf=False)
# Get the existing cache, if it exists, and validate that its params
# haven't changed.
with self._caches_lock:
cache = self._function_caches.get(key)
if (
cache is not None
and cache.ttl_seconds == ttl_seconds
and cache.max_entries == max_entries
and cache.persist == persist
):
return cache
# Close the existing cache's storage, if it exists.
if cache is not None:
_LOGGER.debug(
"Closing existing DataCache storage "
"(key=%s, persist=%s, max_entries=%s, ttl=%s) "
"before creating new one with different params",
key,
persist,
max_entries,
ttl,
)
cache.storage.close()
# Create a new cache object and put it in our dict
_LOGGER.debug(
"Creating new DataCache (key=%s, persist=%s, max_entries=%s, ttl=%s)",
key,
persist,
max_entries,
ttl,
)
cache_context = self.create_cache_storage_context(
function_key=key,
function_name=display_name,
ttl_seconds=ttl_seconds,
max_entries=max_entries,
persist=persist,
)
cache_storage_manager = self.get_storage_manager()
storage = cache_storage_manager.create(cache_context)
cache = DataCache(
key=key,
storage=storage,
persist=persist,
max_entries=max_entries,
ttl_seconds=ttl_seconds,
display_name=display_name,
)
self._function_caches[key] = cache
return cache
def clear_all(self) -> None:
"""Clear all in-memory and on-disk caches."""
with self._caches_lock:
try:
# try to remove in optimal way if such ability provided by
# storage manager clear_all method;
# if not implemented, fallback to remove all
# available storages one by one
self.get_storage_manager().clear_all()
except NotImplementedError:
for data_cache in self._function_caches.values():
data_cache.clear()
data_cache.storage.close()
self._function_caches = {}
def get_stats(self) -> list[CacheStat]:
with self._caches_lock:
# Shallow-clone our caches. We don't want to hold the global
# lock during stats-gathering.
function_caches = self._function_caches.copy()
stats: list[CacheStat] = []
for cache in function_caches.values():
stats.extend(cache.get_stats())
return group_stats(stats)
def validate_cache_params(
self,
function_name: str,
persist: CachePersistType,
max_entries: int | None,
ttl: int | float | timedelta | str | None,
) -> None:
"""Validate that the cache params are valid for given storage.
Raises
------
InvalidCacheStorageContext
Raised if the cache storage manager is not able to work with provided
CacheStorageContext.
"""
ttl_seconds = time_to_seconds(ttl, coerce_none_to_inf=False)
cache_context = self.create_cache_storage_context(
function_key="DUMMY_KEY",
function_name=function_name,
ttl_seconds=ttl_seconds,
max_entries=max_entries,
persist=persist,
)
try:
self.get_storage_manager().check_context(cache_context)
except InvalidCacheStorageContext as e:
_LOGGER.error(
"Cache params for function %s are incompatible with current "
"cache storage manager.",
function_name,
exc_info=e,
)
raise
def create_cache_storage_context(
self,
function_key: str,
function_name: str,
persist: CachePersistType,
ttl_seconds: float | None,
max_entries: int | None,
) -> CacheStorageContext:
return CacheStorageContext(
function_key=function_key,
function_display_name=function_name,
ttl_seconds=ttl_seconds,
max_entries=max_entries,
persist=persist,
)
def get_storage_manager(self) -> CacheStorageManager:
if runtime.exists():
return runtime.get_instance().cache_storage_manager
else:
# When running in "raw mode", we can't access the CacheStorageManager,
# so we're falling back to InMemoryCache.
_LOGGER.warning("No runtime found, using MemoryCacheStorageManager")
return MemoryCacheStorageManager()
# Singleton DataCaches instance
_data_caches = DataCaches()
def get_data_cache_stats_provider() -> CacheStatsProvider:
"""Return the StatsProvider for all @st.cache_data functions."""
return _data_caches
class CacheDataAPI:
"""Implements the public st.cache_data API: the @st.cache_data decorator, and
st.cache_data.clear().
"""
def __init__(self, decorator_metric_name: str):
"""Create a CacheDataAPI instance.
Parameters
----------
decorator_metric_name
The metric name to record for decorator usage.
"""
# Parameterize the decorator metric name.
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
self._decorator = gather_metrics( # type: ignore
decorator_metric_name, self._decorator
)
# Type-annotate the decorator function.
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
F = TypeVar("F", bound=Callable[..., Any])
# Bare decorator usage
@overload
def __call__(self, func: F) -> F: ...
# Decorator with arguments
@overload
def __call__(
self,
*,
ttl: float | timedelta | str | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
persist: CachePersistType | bool = None,
experimental_allow_widgets: bool = False,
hash_funcs: HashFuncsDict | None = None,
) -> Callable[[F], F]: ...
def __call__(
self,
func: F | None = None,
*,
ttl: float | timedelta | str | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
persist: CachePersistType | bool = None,
experimental_allow_widgets: bool = False,
hash_funcs: HashFuncsDict | None = None,
):
return self._decorator(
func,
ttl=ttl,
max_entries=max_entries,
persist=persist,
show_spinner=show_spinner,
experimental_allow_widgets=experimental_allow_widgets,
hash_funcs=hash_funcs,
)
def _decorator(
self,
func: F | None = None,
*,
ttl: float | timedelta | str | None,
max_entries: int | None,
show_spinner: bool | str,
persist: CachePersistType | bool,
experimental_allow_widgets: bool,
hash_funcs: HashFuncsDict | None = None,
):
"""Decorator to cache functions that return data (e.g. dataframe transforms, database queries, ML inference).
Cached objects are stored in "pickled" form, which means that the return
value of a cached function must be pickleable. Each caller of the cached
function gets its own copy of the cached data.
You can clear a function's cache with ``func.clear()`` or clear the entire
cache with ``st.cache_data.clear()``.
A function's arguments must be hashable to cache it. If you have an
unhashable argument (like a database connection) or an argument you
want to exclude from caching, use an underscore prefix in the argument
name. In this case, Streamlit will return a cached value when all other
arguments match a previous function call. Alternatively, you can
declare custom hashing functions with ``hash_funcs``.
To cache global resources, use ``st.cache_resource`` instead. Learn more
about caching at https://docs.streamlit.io/develop/concepts/architecture/caching.
Parameters
----------
func : callable
The function to cache. Streamlit hashes the function's source code.
ttl : float, timedelta, str, or None
The maximum time to keep an entry in the cache. Can be one of:
- ``None`` if cache entries should never expire (default).
- A number specifying the time in seconds.
- A string specifying the time in a format supported by `Pandas's
Timedelta constructor <https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html>`_,
e.g. ``"1d"``, ``"1.5 days"``, or ``"1h23s"``.
- A ``timedelta`` object from `Python's built-in datetime library
<https://docs.python.org/3/library/datetime.html#timedelta-objects>`_,
e.g. ``timedelta(days=1)``.
Note that ``ttl`` will be ignored if ``persist="disk"`` or ``persist=True``.
max_entries : int or None
The maximum number of entries to keep in the cache, or None
for an unbounded cache. When a new entry is added to a full cache,
the oldest cached entry will be removed. Defaults to None.
show_spinner : bool or str
Enable the spinner. Default is True to show a spinner when there is
a "cache miss" and the cached data is being created. If string,
value of show_spinner param will be used for spinner text.
persist : "disk", bool, or None
Optional location to persist cached data to. Passing "disk" (or True)
will persist the cached data to the local disk. None (or False) will disable
persistence. The default is None.
experimental_allow_widgets : bool
Allow widgets to be used in the cached function. Defaults to False.
hash_funcs : dict or None
Mapping of types or fully qualified names to hash functions.
This is used to override the behavior of the hasher inside Streamlit's
caching mechanism: when the hasher encounters an object, it will first
check to see if its type matches a key in this dict and, if so, will use
the provided function to generate a hash for it. See below for an example
of how this can be used.
.. deprecated::
The cached widget replay functionality was removed in 1.38. Please
remove the ``experimental_allow_widgets`` parameter from your
caching decorators. This parameter will be removed in a future
version.
Example
-------
>>> import streamlit as st
>>>
>>> @st.cache_data
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
>>>
>>> d1 = fetch_and_clean_data(DATA_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> d2 = fetch_and_clean_data(DATA_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value. This means that now the data in d1 is the same as in d2.
>>>
>>> d3 = fetch_and_clean_data(DATA_URL_2)
>>> # This is a different URL, so the function executes.
To set the ``persist`` parameter, use this command as follows:
>>> import streamlit as st
>>>
>>> @st.cache_data(persist="disk")
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
By default, all parameters to a cached function must be hashable.
Any parameter whose name begins with ``_`` will not be hashed. You can use
this as an "escape hatch" for parameters that are not hashable:
>>> import streamlit as st
>>>
>>> @st.cache_data
... def fetch_and_clean_data(_db_connection, num_rows):
... # Fetch data from _db_connection here, and then clean it up.
... return data
>>>
>>> connection = make_database_connection()
>>> d1 = fetch_and_clean_data(connection, num_rows=10)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> another_connection = make_database_connection()
>>> d2 = fetch_and_clean_data(another_connection, num_rows=10)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value - even though the _database_connection parameter was different
>>> # in both calls.
A cached function's cache can be procedurally cleared:
>>> import streamlit as st
>>>
>>> @st.cache_data
... def fetch_and_clean_data(_db_connection, num_rows):
... # Fetch data from _db_connection here, and then clean it up.
... return data
>>>
>>> fetch_and_clean_data.clear(_db_connection, 50)
>>> # Clear the cached entry for the arguments provided.
>>>
>>> fetch_and_clean_data.clear()
>>> # Clear all cached entries for this function.
To override the default hashing behavior, pass a custom hash function.
You can do that by mapping a type (e.g. ``datetime.datetime``) to a hash
function (``lambda dt: dt.isoformat()``) like this:
>>> import streamlit as st
>>> import datetime
>>>
>>> @st.cache_data(hash_funcs={datetime.datetime: lambda dt: dt.isoformat()})
... def convert_to_utc(dt: datetime.datetime):
... return dt.astimezone(datetime.timezone.utc)
Alternatively, you can map the type's fully-qualified name
(e.g. ``"datetime.datetime"``) to the hash function instead:
>>> import streamlit as st
>>> import datetime
>>>
>>> @st.cache_data(hash_funcs={"datetime.datetime": lambda dt: dt.isoformat()})
... def convert_to_utc(dt: datetime.datetime):
... return dt.astimezone(datetime.timezone.utc)
"""
# Parse our persist value into a string
persist_string: CachePersistType
if persist is True:
persist_string = "disk"
elif persist is False:
persist_string = None
else:
persist_string = persist
if persist_string not in (None, "disk"):
# We'll eventually have more persist options.
raise StreamlitAPIException(
f"Unsupported persist option '{persist}'. Valid values are 'disk' or None."
)
if experimental_allow_widgets:
show_widget_replay_deprecation("cache_data")
def wrapper(f):
return make_cached_func_wrapper(
CachedDataFuncInfo(
func=f,
persist=persist_string,
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
hash_funcs=hash_funcs,
)
)
if func is None:
return wrapper
return make_cached_func_wrapper(
CachedDataFuncInfo(
func=cast("types.FunctionType", func),
persist=persist_string,
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
hash_funcs=hash_funcs,
)
)
@gather_metrics("clear_data_caches")
def clear(self) -> None:
"""Clear all in-memory and on-disk data caches."""
_data_caches.clear_all()
class DataCache(Cache):
"""Manages cached values for a single st.cache_data function."""
def __init__(
self,
key: str,
storage: CacheStorage,
persist: CachePersistType,
max_entries: int | None,
ttl_seconds: float | None,
display_name: str,
):
super().__init__()
self.key = key
self.display_name = display_name
self.storage = storage
self.ttl_seconds = ttl_seconds
self.max_entries = max_entries
self.persist = persist
def get_stats(self) -> list[CacheStat]:
if isinstance(self.storage, CacheStatsProvider):
return self.storage.get_stats()
return []
def read_result(self, key: str) -> CachedResult:
"""Read a value and messages from the cache. Raise `CacheKeyNotFoundError`
if the value doesn't exist, and `CacheError` if the value exists but can't
be unpickled.
"""
try:
pickled_entry = self.storage.get(key)
except CacheStorageKeyNotFoundError as e:
raise CacheKeyNotFoundError(str(e)) from e
except CacheStorageError as e:
raise CacheError(str(e)) from e
try:
entry = pickle.loads(pickled_entry)
if not isinstance(entry, CachedResult):
# Loaded an old cache file format, remove it and let the caller
# rerun the function.
self.storage.delete(key)
raise CacheKeyNotFoundError()
return entry
except pickle.UnpicklingError as exc:
raise CacheError(f"Failed to unpickle {key}") from exc
@gather_metrics("_cache_data_object")
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
"""Write a value and associated messages to the cache.
The value must be pickleable.
"""
try:
main_id = st._main.id
sidebar_id = st.sidebar.id
entry = CachedResult(value, messages, main_id, sidebar_id)
pickled_entry = pickle.dumps(entry)
except (pickle.PicklingError, TypeError) as exc:
raise CacheError(f"Failed to pickle {key}") from exc
self.storage.set(key, pickled_entry)
def _clear(self, key: str | None = None) -> None:
if not key:
self.storage.clear()
else:
self.storage.delete(key)

View File

@@ -0,0 +1,142 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from streamlit import type_util
from streamlit.errors import MarkdownFormattedException, StreamlitAPIException
from streamlit.runtime.caching.cache_type import CacheType, get_decorator_api_name
if TYPE_CHECKING:
from types import FunctionType
CACHE_DOCS_URL = "https://docs.streamlit.io/develop/concepts/architecture/caching"
def get_cached_func_name_md(func: Any) -> str:
"""Get markdown representation of the function name."""
if hasattr(func, "__name__"):
return f"`{func.__name__}()`"
elif hasattr(type(func), "__name__"):
return f"`{type(func).__name__}`"
return f"`{type(func)}`"
def get_return_value_type(return_value: Any) -> str:
if hasattr(return_value, "__module__") and hasattr(type(return_value), "__name__"):
return f"`{return_value.__module__}.{type(return_value).__name__}`"
return get_cached_func_name_md(return_value)
class UnhashableTypeError(Exception):
pass
class UnhashableParamError(StreamlitAPIException):
def __init__(
self,
cache_type: CacheType,
func: FunctionType,
arg_name: str | None,
arg_value: Any,
orig_exc: BaseException,
):
msg = self._create_message(cache_type, func, arg_name, arg_value)
super().__init__(msg)
self.with_traceback(orig_exc.__traceback__)
@staticmethod
def _create_message(
cache_type: CacheType,
func: FunctionType,
arg_name: str | None,
arg_value: Any,
) -> str:
arg_name_str = arg_name if arg_name is not None else "(unnamed)"
arg_type = type_util.get_fqn_type(arg_value)
func_name = func.__name__
arg_replacement_name = f"_{arg_name}" if arg_name is not None else "_arg"
return (
f"""
Cannot hash argument '{arg_name_str}' (of type `{arg_type}`) in '{func_name}'.
To address this, you can tell Streamlit not to hash this argument by adding a
leading underscore to the argument's name in the function signature:
```
@st.{get_decorator_api_name(cache_type)}
def {func_name}({arg_replacement_name}, ...):
...
```
"""
).strip("\n")
class CacheKeyNotFoundError(Exception):
pass
class CacheError(Exception):
pass
class CacheReplayClosureError(StreamlitAPIException):
def __init__(
self,
cache_type: CacheType,
cached_func: FunctionType,
):
func_name = get_cached_func_name_md(cached_func)
decorator_name = get_decorator_api_name(cache_type)
msg = (
f"""
While running {func_name}, a streamlit element is called on some layout block
created outside the function. This is incompatible with replaying the cached
effect of that element, because the the referenced block might not exist when
the replay happens.
How to fix this:
* Move the creation of $THING inside {func_name}.
* Move the call to the streamlit element outside of {func_name}.
* Remove the `@st.{decorator_name}` decorator from {func_name}.
"""
).strip("\n")
super().__init__(msg)
class UnserializableReturnValueError(MarkdownFormattedException):
def __init__(self, func: FunctionType, return_value: FunctionType):
MarkdownFormattedException.__init__(
self,
f"""
Cannot serialize the return value (of type {get_return_value_type(return_value)})
in {get_cached_func_name_md(func)}. `st.cache_data` uses
[pickle](https://docs.python.org/3/library/pickle.html) to serialize the
function's return value and safely store it in the cache
without mutating the original object. Please convert the return value to a
pickle-serializable type. If you want to cache unserializable objects such
as database connections or Tensorflow sessions, use `st.cache_resource`
instead (see [our docs]({CACHE_DOCS_URL}) for differences).""",
)
class UnevaluatedDataFrameError(StreamlitAPIException):
"""Used to display a message about uncollected dataframe being used."""
pass

View File

@@ -0,0 +1,527 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.cache_resource implementation."""
from __future__ import annotations
import math
import threading
from typing import TYPE_CHECKING, Any, Callable, Final, TypeVar, cast, overload
from cachetools import TTLCache
from typing_extensions import TypeAlias
import streamlit as st
from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.cache_errors import CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
MsgData,
show_widget_replay_deprecation,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
from streamlit.time_util import time_to_seconds
if TYPE_CHECKING:
import types
from datetime import timedelta
from streamlit.runtime.caching.hashing import HashFuncsDict
_LOGGER: Final = get_logger(__name__)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.RESOURCE)
ValidateFunc: TypeAlias = Callable[[Any], bool]
def _equal_validate_funcs(a: ValidateFunc | None, b: ValidateFunc | None) -> bool:
"""True if the two validate functions are equal for the purposes of
determining whether a given function cache needs to be recreated.
"""
# To "properly" test for function equality here, we'd need to compare function bytecode.
# For performance reasons, We've decided not to do that for now.
return (a is None and b is None) or (a is not None and b is not None)
class ResourceCaches(CacheStatsProvider):
"""Manages all ResourceCache instances."""
def __init__(self):
self._caches_lock = threading.Lock()
self._function_caches: dict[str, ResourceCache] = {}
def get_cache(
self,
key: str,
display_name: str,
max_entries: int | float | None,
ttl: float | timedelta | str | None,
validate: ValidateFunc | None,
) -> ResourceCache:
"""Return the mem cache for the given key.
If it doesn't exist, create a new one with the given params.
"""
if max_entries is None:
max_entries = math.inf
ttl_seconds = time_to_seconds(ttl)
# Get the existing cache, if it exists, and validate that its params
# haven't changed.
with self._caches_lock:
cache = self._function_caches.get(key)
if (
cache is not None
and cache.ttl_seconds == ttl_seconds
and cache.max_entries == max_entries
and _equal_validate_funcs(cache.validate, validate)
):
return cache
# Create a new cache object and put it in our dict
_LOGGER.debug("Creating new ResourceCache (key=%s)", key)
cache = ResourceCache(
key=key,
display_name=display_name,
max_entries=max_entries,
ttl_seconds=ttl_seconds,
validate=validate,
)
self._function_caches[key] = cache
return cache
def clear_all(self) -> None:
"""Clear all resource caches."""
with self._caches_lock:
self._function_caches = {}
def get_stats(self) -> list[CacheStat]:
with self._caches_lock:
# Shallow-clone our caches. We don't want to hold the global
# lock during stats-gathering.
function_caches = self._function_caches.copy()
stats: list[CacheStat] = []
for cache in function_caches.values():
stats.extend(cache.get_stats())
return group_stats(stats)
# Singleton ResourceCaches instance
_resource_caches = ResourceCaches()
def get_resource_cache_stats_provider() -> CacheStatsProvider:
"""Return the StatsProvider for all @st.cache_resource functions."""
return _resource_caches
class CachedResourceFuncInfo(CachedFuncInfo):
"""Implements the CachedFuncInfo interface for @st.cache_resource."""
def __init__(
self,
func: types.FunctionType,
show_spinner: bool | str,
max_entries: int | None,
ttl: float | timedelta | str | None,
validate: ValidateFunc | None,
hash_funcs: HashFuncsDict | None = None,
):
super().__init__(
func,
show_spinner=show_spinner,
hash_funcs=hash_funcs,
)
self.max_entries = max_entries
self.ttl = ttl
self.validate = validate
@property
def cache_type(self) -> CacheType:
return CacheType.RESOURCE
@property
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
return CACHE_RESOURCE_MESSAGE_REPLAY_CTX
@property
def display_name(self) -> str:
"""A human-readable name for the cached function."""
return f"{self.func.__module__}.{self.func.__qualname__}"
def get_function_cache(self, function_key: str) -> Cache:
return _resource_caches.get_cache(
key=function_key,
display_name=self.display_name,
max_entries=self.max_entries,
ttl=self.ttl,
validate=self.validate,
)
class CacheResourceAPI:
"""Implements the public st.cache_resource API: the @st.cache_resource decorator,
and st.cache_resource.clear().
"""
def __init__(self, decorator_metric_name: str):
"""Create a CacheResourceAPI instance.
Parameters
----------
decorator_metric_name
The metric name to record for decorator usage.
"""
# Parameterize the decorator metric name.
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
self._decorator = gather_metrics(decorator_metric_name, self._decorator) # type: ignore
# Type-annotate the decorator function.
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
F = TypeVar("F", bound=Callable[..., Any])
# Bare decorator usage
@overload
def __call__(self, func: F) -> F: ...
# Decorator with arguments
@overload
def __call__(
self,
*,
ttl: float | timedelta | str | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
validate: ValidateFunc | None = None,
experimental_allow_widgets: bool = False,
hash_funcs: HashFuncsDict | None = None,
) -> Callable[[F], F]: ...
def __call__(
self,
func: F | None = None,
*,
ttl: float | timedelta | str | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
validate: ValidateFunc | None = None,
experimental_allow_widgets: bool = False,
hash_funcs: HashFuncsDict | None = None,
):
return self._decorator(
func,
ttl=ttl,
max_entries=max_entries,
show_spinner=show_spinner,
validate=validate,
experimental_allow_widgets=experimental_allow_widgets,
hash_funcs=hash_funcs,
)
def _decorator(
self,
func: F | None,
*,
ttl: float | timedelta | str | None,
max_entries: int | None,
show_spinner: bool | str,
validate: ValidateFunc | None,
experimental_allow_widgets: bool,
hash_funcs: HashFuncsDict | None = None,
):
"""Decorator to cache functions that return global resources (e.g. database connections, ML models).
Cached objects are shared across all users, sessions, and reruns. They
must be thread-safe because they can be accessed from multiple threads
concurrently. If thread safety is an issue, consider using ``st.session_state``
to store resources per session instead.
You can clear a function's cache with ``func.clear()`` or clear the entire
cache with ``st.cache_resource.clear()``.
A function's arguments must be hashable to cache it. If you have an
unhashable argument (like a database connection) or an argument you
want to exclude from caching, use an underscore prefix in the argument
name. In this case, Streamlit will return a cached value when all other
arguments match a previous function call. Alternatively, you can
declare custom hashing functions with ``hash_funcs``.
To cache data, use ``st.cache_data`` instead. Learn more about caching at
https://docs.streamlit.io/develop/concepts/architecture/caching.
Parameters
----------
func : callable
The function that creates the cached resource. Streamlit hashes the
function's source code.
ttl : float, timedelta, str, or None
The maximum time to keep an entry in the cache. Can be one of:
- ``None`` if cache entries should never expire (default).
- A number specifying the time in seconds.
- A string specifying the time in a format supported by `Pandas's
Timedelta constructor <https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html>`_,
e.g. ``"1d"``, ``"1.5 days"``, or ``"1h23s"``.
- A ``timedelta`` object from `Python's built-in datetime library
<https://docs.python.org/3/library/datetime.html#timedelta-objects>`_,
e.g. ``timedelta(days=1)``.
max_entries : int or None
The maximum number of entries to keep in the cache, or None
for an unbounded cache. When a new entry is added to a full cache,
the oldest cached entry will be removed. Defaults to None.
show_spinner : bool or str
Enable the spinner. Default is True to show a spinner when there is
a "cache miss" and the cached resource is being created. If string,
value of show_spinner param will be used for spinner text.
validate : callable or None
An optional validation function for cached data. ``validate`` is called
each time the cached value is accessed. It receives the cached value as
its only parameter and it must return a boolean. If ``validate`` returns
False, the current cached value is discarded, and the decorated function
is called to compute a new value. This is useful e.g. to check the
health of database connections.
experimental_allow_widgets : bool
Allow widgets to be used in the cached function. Defaults to False.
hash_funcs : dict or None
Mapping of types or fully qualified names to hash functions.
This is used to override the behavior of the hasher inside Streamlit's
caching mechanism: when the hasher encounters an object, it will first
check to see if its type matches a key in this dict and, if so, will use
the provided function to generate a hash for it. See below for an example
of how this can be used.
.. deprecated::
The cached widget replay functionality was removed in 1.38. Please
remove the ``experimental_allow_widgets`` parameter from your
caching decorators. This parameter will be removed in a future
version.
Example
-------
>>> import streamlit as st
>>>
>>> @st.cache_resource
... def get_database_session(url):
... # Create a database session object that points to the URL.
... return session
>>>
>>> s1 = get_database_session(SESSION_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> s2 = get_database_session(SESSION_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value. This means that now the connection object in s1 is the same as in s2.
>>>
>>> s3 = get_database_session(SESSION_URL_2)
>>> # This is a different URL, so the function executes.
By default, all parameters to a cache_resource function must be hashable.
Any parameter whose name begins with ``_`` will not be hashed. You can use
this as an "escape hatch" for parameters that are not hashable:
>>> import streamlit as st
>>>
>>> @st.cache_resource
... def get_database_session(_sessionmaker, url):
... # Create a database connection object that points to the URL.
... return connection
>>>
>>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value - even though the _sessionmaker parameter was different
>>> # in both calls.
A cache_resource function's cache can be procedurally cleared:
>>> import streamlit as st
>>>
>>> @st.cache_resource
... def get_database_session(_sessionmaker, url):
... # Create a database connection object that points to the URL.
... return connection
>>>
>>> fetch_and_clean_data.clear(_sessionmaker, "https://streamlit.io/")
>>> # Clear the cached entry for the arguments provided.
>>>
>>> get_database_session.clear()
>>> # Clear all cached entries for this function.
To override the default hashing behavior, pass a custom hash function.
You can do that by mapping a type (e.g. ``Person``) to a hash
function (``str``) like this:
>>> import streamlit as st
>>> from pydantic import BaseModel
>>>
>>> class Person(BaseModel):
... name: str
>>>
>>> @st.cache_resource(hash_funcs={Person: str})
... def get_person_name(person: Person):
... return person.name
Alternatively, you can map the type's fully-qualified name
(e.g. ``"__main__.Person"``) to the hash function instead:
>>> import streamlit as st
>>> from pydantic import BaseModel
>>>
>>> class Person(BaseModel):
... name: str
>>>
>>> @st.cache_resource(hash_funcs={"__main__.Person": str})
... def get_person_name(person: Person):
... return person.name
"""
if experimental_allow_widgets:
show_widget_replay_deprecation("cache_resource")
# Support passing the params via function decorator, e.g.
# @st.cache_resource(show_spinner=False)
if func is None:
return lambda f: make_cached_func_wrapper(
CachedResourceFuncInfo(
func=f,
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
validate=validate,
hash_funcs=hash_funcs,
)
)
return make_cached_func_wrapper(
CachedResourceFuncInfo(
func=cast("types.FunctionType", func),
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
validate=validate,
hash_funcs=hash_funcs,
)
)
@gather_metrics("clear_resource_caches")
def clear(self) -> None:
"""Clear all cache_resource caches."""
_resource_caches.clear_all()
class ResourceCache(Cache):
"""Manages cached values for a single st.cache_resource function."""
def __init__(
self,
key: str,
max_entries: float,
ttl_seconds: float,
validate: ValidateFunc | None,
display_name: str,
):
super().__init__()
self.key = key
self.display_name = display_name
self._mem_cache: TTLCache[str, CachedResult] = TTLCache(
maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
)
self._mem_cache_lock = threading.Lock()
self.validate = validate
@property
def max_entries(self) -> float:
return self._mem_cache.maxsize
@property
def ttl_seconds(self) -> float:
return self._mem_cache.ttl
def read_result(self, key: str) -> CachedResult:
"""Read a value and associated messages from the cache.
Raise `CacheKeyNotFoundError` if the value doesn't exist.
"""
with self._mem_cache_lock:
if key not in self._mem_cache:
# key does not exist in cache.
raise CacheKeyNotFoundError()
result = self._mem_cache[key]
if self.validate is not None and not self.validate(result.value):
# Validate failed: delete the entry and raise an error.
del self._mem_cache[key]
raise CacheKeyNotFoundError()
return result
@gather_metrics("_cache_resource_object")
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
"""Write a value and associated messages to the cache."""
main_id = st._main.id
sidebar_id = st.sidebar.id
with self._mem_cache_lock:
self._mem_cache[key] = CachedResult(value, messages, main_id, sidebar_id)
def _clear(self, key: str | None = None) -> None:
with self._mem_cache_lock:
if key is None:
self._mem_cache.clear()
elif key in self._mem_cache:
del self._mem_cache[key]
def get_stats(self) -> list[CacheStat]:
# Shallow clone our cache. Computing item sizes is potentially
# expensive, and we want to minimize the time we spend holding
# the lock.
with self._mem_cache_lock:
cache_entries = list(self._mem_cache.values())
# Lazy-load vendored package to prevent import of numpy
from streamlit.vendor.pympler.asizeof import asizeof
return [
CacheStat(
category_name="st_cache_resource",
cache_name=self.display_name,
byte_length=asizeof(entry),
)
for entry in cache_entries
]

View File

@@ -0,0 +1,33 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import enum
class CacheType(enum.Enum):
"""The function cache types we implement."""
DATA = "DATA"
RESOURCE = "RESOURCE"
def get_decorator_api_name(cache_type: CacheType) -> str:
"""Return the name of the public decorator API for the given CacheType."""
if cache_type is CacheType.DATA:
return "cache_data"
if cache_type is CacheType.RESOURCE:
return "cache_resource"
raise RuntimeError(f"Unrecognized CacheType '{cache_type}'")

View File

@@ -0,0 +1,525 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common cache logic shared by st.cache_data and st.cache_resource."""
from __future__ import annotations
import contextlib
import functools
import hashlib
import inspect
import threading
import time
from abc import abstractmethod
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Final
from streamlit import type_util
from streamlit.dataframe_util import is_unevaluated_data_object
from streamlit.elements.spinner import spinner
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import (
CacheError,
CacheKeyNotFoundError,
UnevaluatedDataFrameError,
UnhashableParamError,
UnhashableTypeError,
UnserializableReturnValueError,
get_cached_func_name_md,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
MsgData,
replay_cached_messages,
)
from streamlit.runtime.caching.hashing import HashFuncsDict, update_hash
from streamlit.runtime.scriptrunner_utils.script_run_context import (
in_cached_function,
)
if TYPE_CHECKING:
from types import FunctionType
from streamlit.runtime.caching.cache_type import CacheType
_LOGGER: Final = get_logger(__name__)
# The timer function we use with TTLCache. This is the default timer func, but
# is exposed here as a constant so that it can be patched in unit tests.
TTLCACHE_TIMER = time.monotonic
class Cache:
"""Function cache interface. Caches persist across script runs."""
def __init__(self):
self._value_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
self._value_locks_lock = threading.Lock()
@abstractmethod
def read_result(self, value_key: str) -> CachedResult:
"""Read a value and associated messages from the cache.
Raises
------
CacheKeyNotFoundError
Raised if value_key is not in the cache.
"""
raise NotImplementedError
@abstractmethod
def write_result(self, value_key: str, value: Any, messages: list[MsgData]) -> None:
"""Write a value and associated messages to the cache, overwriting any existing
result that uses the value_key.
"""
# We *could* `del self._value_locks[value_key]` here, since nobody will be taking
# a compute_value_lock for this value_key after the result is written.
raise NotImplementedError
def compute_value_lock(self, value_key: str) -> threading.Lock:
"""Return the lock that should be held while computing a new cached value.
In a popular app with a cache that hasn't been pre-warmed, many sessions may try
to access a not-yet-cached value simultaneously. We use a lock to ensure that
only one of those sessions computes the value, and the others block until
the value is computed.
"""
with self._value_locks_lock:
return self._value_locks[value_key]
def clear(self, key: str | None = None):
"""Clear values from this cache.
If no argument is passed, all items are cleared from the cache.
A key can be passed to clear that key from the cache only.
"""
with self._value_locks_lock:
if not key:
self._value_locks.clear()
elif key in self._value_locks:
del self._value_locks[key]
self._clear(key=key)
@abstractmethod
def _clear(self, key: str | None = None) -> None:
"""Subclasses must implement this to perform cache-clearing logic."""
raise NotImplementedError
class CachedFuncInfo:
"""Encapsulates data for a cached function instance.
CachedFuncInfo instances are scoped to a single script run - they're not
persistent.
"""
def __init__(
self,
func: FunctionType,
show_spinner: bool | str,
hash_funcs: HashFuncsDict | None,
):
self.func = func
self.show_spinner = show_spinner
self.hash_funcs = hash_funcs
@property
def cache_type(self) -> CacheType:
raise NotImplementedError
@property
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
raise NotImplementedError
def get_function_cache(self, function_key: str) -> Cache:
"""Get or create the function cache for the given key."""
raise NotImplementedError
def make_cached_func_wrapper(info: CachedFuncInfo) -> Callable[..., Any]:
"""Create a callable wrapper around a CachedFunctionInfo.
Calling the wrapper will return the cached value if it's already been
computed, and will call the underlying function to compute and cache the
value otherwise.
The wrapper also has a `clear` function that can be called to clear
some or all of the wrapper's cached values.
"""
cached_func = CachedFunc(info)
return functools.update_wrapper(cached_func, info.func)
class BoundCachedFunc:
"""A wrapper around a CachedFunc that binds it to a specific instance in case of
decorated function is a class method.
"""
def __init__(self, cached_func: CachedFunc, instance: Any):
self._cached_func = cached_func
self._instance = instance
def __call__(self, *args, **kwargs) -> Any:
return self._cached_func(self._instance, *args, **kwargs)
def __repr__(self):
return f"<BoundCachedFunc: {self._cached_func._info.func} of {self._instance}>"
def clear(self, *args, **kwargs):
if args or kwargs:
# The instance is required as first parameter to allow
# args to be correctly resolved to the parameter names:
self._cached_func.clear(self._instance, *args, **kwargs)
else:
# if no args/kwargs are specified, we just want to clear the
# entire cache of this method:
self._cached_func.clear()
class CachedFunc:
def __init__(self, info: CachedFuncInfo):
self._info = info
self._function_key = _make_function_key(info.cache_type, info.func)
def __repr__(self):
return f"<CachedFunc: {self._info.func}>"
def __get__(self, instance, owner=None):
"""CachedFunc implements descriptor protocol to support cache methods."""
if instance is None:
return self
return functools.update_wrapper(BoundCachedFunc(self, instance), self)
def __call__(self, *args, **kwargs) -> Any:
"""The wrapper. We'll only call our underlying function on a cache miss."""
spinner_message: str | None = None
if isinstance(self._info.show_spinner, str):
spinner_message = self._info.show_spinner
elif self._info.show_spinner is True:
name = self._info.func.__qualname__
if len(args) == 0 and len(kwargs) == 0:
spinner_message = f"Running `{name}()`."
else:
spinner_message = f"Running `{name}(...)`."
return self._get_or_create_cached_value(args, kwargs, spinner_message)
def _get_or_create_cached_value(
self,
func_args: tuple[Any, ...],
func_kwargs: dict[str, Any],
spinner_message: str | None = None,
) -> Any:
# Retrieve the function's cache object. We must do this "just-in-time"
# (as opposed to in the constructor), because caches can be invalidated
# at any time.
cache = self._info.get_function_cache(self._function_key)
# Generate the key for the cached value. This is based on the
# arguments passed to the function.
value_key = _make_value_key(
cache_type=self._info.cache_type,
func=self._info.func,
func_args=func_args,
func_kwargs=func_kwargs,
hash_funcs=self._info.hash_funcs,
)
with contextlib.suppress(CacheKeyNotFoundError):
cached_result = cache.read_result(value_key)
return self._handle_cache_hit(cached_result)
# only show spinner if there is a message to show and always only for the
# outermost cache function if cache functions are nested, because the outermost
# function has to wait for the inner functions anyways. This avoids surprising
# users with slowdowned apps in case the inner functions are called very often,
# which would lead to a ton of (empty/spinner) proto messages that will make the
# app slow (see https://github.com/streamlit/streamlit/issues/9951). This is
# basically like auto-setting "show_spinner=False" on the @st.cache decorators
# on behalf of the user.
is_nested_cache_function = in_cached_function.get()
spinner_or_no_context = (
spinner(spinner_message, _cache=True)
if spinner_message is not None and not is_nested_cache_function
else contextlib.nullcontext()
)
with spinner_or_no_context:
return self._handle_cache_miss(cache, value_key, func_args, func_kwargs)
def _handle_cache_hit(self, result: CachedResult) -> Any:
"""Handle a cache hit: replay the result's cached messages, and return its
value.
"""
replay_cached_messages(
result,
self._info.cache_type,
self._info.func,
)
return result.value
def _handle_cache_miss(
self,
cache: Cache,
value_key: str,
func_args: tuple[Any, ...],
func_kwargs: dict[str, Any],
) -> Any:
"""Handle a cache miss: compute a new cached value, write it back to the cache,
and return that newly-computed value.
"""
# Implementation notes:
# - We take a "compute_value_lock" before computing our value. This ensures that
# multiple sessions don't try to compute the same value simultaneously.
#
# - We use a different lock for each value_key, as opposed to a single lock for
# the entire cache, so that unrelated value computations don't block on each other.
#
# - When retrieving a cache entry that may not yet exist, we use a "double-checked locking"
# strategy: first we try to retrieve the cache entry without taking a value lock. (This
# happens in `_get_or_create_cached_value()`.) If that fails because the value hasn't
# been computed yet, we take the value lock and then immediately try to retrieve cache entry
# *again*, while holding the lock. If the cache entry exists at this point, it means that
# another thread computed the value before us.
#
# This means that the happy path ("cache entry exists") is a wee bit faster because
# no lock is acquired. But the unhappy path ("cache entry needs to be recomputed") is
# a wee bit slower, because we do two lookups for the entry.
with cache.compute_value_lock(value_key):
# We've acquired the lock - but another thread may have acquired it first
# and already computed the value. So we need to test for a cache hit again,
# before computing.
try:
cached_result = cache.read_result(value_key)
# Another thread computed the value before us. Early exit!
return self._handle_cache_hit(cached_result)
except CacheKeyNotFoundError:
# No cache hit -> we will call the cached function
# below.
pass
# We acquired the lock before any other thread. Compute the value!
with self._info.cached_message_replay_ctx.calling_cached_function(
self._info.func
):
computed_value = self._info.func(*func_args, **func_kwargs)
# We've computed our value, and now we need to write it back to the cache
# along with any "replay messages" that were generated during value computation.
messages = self._info.cached_message_replay_ctx._most_recent_messages
try:
cache.write_result(value_key, computed_value, messages)
return computed_value
except (CacheError, RuntimeError) as ex:
# An exception was thrown while we tried to write to the cache. Report
# it to the user. (We catch `RuntimeError` here because it will be
# raised by Apache Spark if we do not collect dataframe before
# using `st.cache_data`.)
if is_unevaluated_data_object(computed_value):
# If the returned value is an unevaluated dataframe, raise an error.
# Unevaluated dataframes are not yet in the local memory, which also
# means they cannot be properly cached (serialized).
raise UnevaluatedDataFrameError(
f"The function {get_cached_func_name_md(self._info.func)} is "
"decorated with `st.cache_data` but it returns an unevaluated "
f"data object of type `{type_util.get_fqn_type(computed_value)}`. "
"Please convert the object to a serializable format "
"(e.g. Pandas DataFrame) before returning it, so "
"`st.cache_data` can serialize and cache it."
) from ex
raise UnserializableReturnValueError(
return_value=computed_value, func=self._info.func
)
def clear(self, *args, **kwargs):
"""Clear the cached function's associated cache.
If no arguments are passed, Streamlit will clear all values cached for
the function. If arguments are passed, Streamlit will clear the cached
value for these arguments only.
Parameters
----------
*args: Any
Arguments of the cached functions.
**kwargs: Any
Keyword arguments of the cached function.
Example
-------
>>> import streamlit as st
>>> import time
>>>
>>> @st.cache_data
>>> def foo(bar):
>>> time.sleep(2)
>>> st.write(f"Executed foo({bar}).")
>>> return bar
>>>
>>> if st.button("Clear all cached values for `foo`", on_click=foo.clear):
>>> foo.clear()
>>>
>>> if st.button("Clear the cached value of `foo(1)`"):
>>> foo.clear(1)
>>>
>>> foo(1)
>>> foo(2)
"""
cache = self._info.get_function_cache(self._function_key)
if args or kwargs:
key = _make_value_key(
cache_type=self._info.cache_type,
func=self._info.func,
func_args=args,
func_kwargs=kwargs,
hash_funcs=self._info.hash_funcs,
)
else:
key = None
cache.clear(key=key)
def _make_value_key(
cache_type: CacheType,
func: FunctionType,
func_args: tuple[Any, ...],
func_kwargs: dict[str, Any],
hash_funcs: HashFuncsDict | None,
) -> str:
"""Create the key for a value within a cache.
This key is generated from the function's arguments. All arguments
will be hashed, except for those named with a leading "_".
Raises
------
StreamlitAPIException
Raised (with a nicely-formatted explanation message) if we encounter
an un-hashable arg.
"""
# Create a (name, value) list of all *args and **kwargs passed to the
# function.
arg_pairs: list[tuple[str | None, Any]] = []
for arg_idx in range(len(func_args)):
arg_name = _get_positional_arg_name(func, arg_idx)
arg_pairs.append((arg_name, func_args[arg_idx]))
for kw_name, kw_val in func_kwargs.items():
# **kwargs ordering is preserved, per PEP 468
# https://www.python.org/dev/peps/pep-0468/, so this iteration is
# deterministic.
arg_pairs.append((kw_name, kw_val))
# Create the hash from each arg value, except for those args whose name
# starts with "_". (Underscore-prefixed args are deliberately excluded from
# hashing.)
args_hasher = hashlib.new("md5", usedforsecurity=False)
for arg_name, arg_value in arg_pairs:
if arg_name is not None and arg_name.startswith("_"):
_LOGGER.debug("Not hashing %s because it starts with _", arg_name)
continue
try:
update_hash(
arg_name,
hasher=args_hasher,
cache_type=cache_type,
hash_source=func,
)
# we call update_hash twice here, first time for `arg_name`
# without `hash_funcs`, and second time for `arg_value` with hash_funcs
# to evaluate user defined `hash_funcs` only for computing `arg_value` hash.
update_hash(
arg_value,
hasher=args_hasher,
cache_type=cache_type,
hash_funcs=hash_funcs,
hash_source=func,
)
except UnhashableTypeError as exc:
raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)
value_key = args_hasher.hexdigest()
_LOGGER.debug("Cache key: %s", value_key)
return value_key
def _make_function_key(cache_type: CacheType, func: FunctionType) -> str:
"""Create the unique key for a function's cache.
A function's key is stable across reruns of the app, and changes when
the function's source code changes.
"""
func_hasher = hashlib.new("md5", usedforsecurity=False)
# Include the function's __module__ and __qualname__ strings in the hash.
# This means that two identical functions in different modules
# will not share a hash; it also means that two identical *nested*
# functions in the same module will not share a hash.
update_hash(
(func.__module__, func.__qualname__),
hasher=func_hasher,
cache_type=cache_type,
hash_source=func,
)
# Include the function's source code in its hash. If the source code can't
# be retrieved, fall back to the function's bytecode instead.
source_code: str | bytes
try:
source_code = inspect.getsource(func)
except (OSError, TypeError) as ex:
_LOGGER.debug(
"Failed to retrieve function's source code when building its key; "
"falling back to bytecode.",
exc_info=ex,
)
source_code = func.__code__.co_code
update_hash(
source_code, hasher=func_hasher, cache_type=cache_type, hash_source=func
)
return func_hasher.hexdigest()
def _get_positional_arg_name(func: FunctionType, arg_index: int) -> str | None:
"""Return the name of a function's positional argument.
If arg_index is out of range, or refers to a parameter that is not a
named positional argument (e.g. an *args, **kwargs, or keyword-only param),
return None instead.
"""
if arg_index < 0:
return None
params: list[inspect.Parameter] = list(inspect.signature(func).parameters.values())
if arg_index >= len(params):
return None
if params[arg_index].kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
):
return params[arg_index].name
return None

View File

@@ -0,0 +1,290 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import threading
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Union
import streamlit as st
from streamlit import runtime, util
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.runtime.caching.cache_errors import CacheReplayClosureError
from streamlit.runtime.scriptrunner_utils.script_run_context import (
in_cached_function,
)
if TYPE_CHECKING:
from collections.abc import Iterator
from types import FunctionType
from google.protobuf.message import Message
from streamlit.delta_generator import DeltaGenerator
from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_type import CacheType
@dataclass(frozen=True)
class MediaMsgData:
media: bytes | str
mimetype: str
media_id: str
@dataclass(frozen=True)
class ElementMsgData:
"""An element's message and related metadata for
replaying that element's function call.
media_data is filled in iff this is a media element (image, audio, video).
"""
delta_type: str
message: Message
id_of_dg_called_on: str
returned_dgs_id: str
media_data: list[MediaMsgData] | None = None
@dataclass(frozen=True)
class BlockMsgData:
message: Block
id_of_dg_called_on: str
returned_dgs_id: str
MsgData = Union[ElementMsgData, BlockMsgData]
@dataclass
class CachedResult:
"""The full results of calling a cache-decorated function, enough to
replay the st functions called while executing it.
"""
value: Any
messages: list[MsgData]
main_id: str
sidebar_id: str
"""
Note [DeltaGenerator method invocation]
There are two top level DG instances defined for all apps:
`main`, which is for putting elements in the main part of the app
`sidebar`, for the sidebar
There are 3 different ways an st function can be invoked:
1. Implicitly on the main DG instance (plain `st.foo` calls)
2. Implicitly in an active contextmanager block (`st.foo` within a `with st.container` context)
3. Explicitly on a DG instance (`st.sidebar.foo`, `my_column_1.foo`)
To simplify replaying messages from a cached function result, we convert all of these
to explicit invocations. How they get rewritten depends on if the invocation was
implicit vs explicit, and if the target DG has been seen/produced during replay.
Implicit invocation on a known DG -> Explicit invocation on that DG
Implicit invocation on an unknown DG -> Rewrite as explicit invocation on main
with st.container():
my_cache_decorated_function()
This is situation 2 above, and the DG is a block entirely outside our function call,
so we interpret it as "put this element in the enclosing contextmanager block"
(or main if there isn't one), which is achieved by invoking on main.
Explicit invocation on a known DG -> No change needed
Explicit invocation on an unknown DG -> Raise an error
We have no way to identify the target DG, and it may not even be present in the
current script run, so the least surprising thing to do is raise an error.
"""
class CachedMessageReplayContext(threading.local):
"""A utility for storing messages generated by `st` commands called inside
a cached function.
Data is stored in a thread-local object, so it's safe to use an instance
of this class across multiple threads.
"""
def __init__(self, cache_type: CacheType):
self._cached_message_stack: list[list[MsgData]] = []
self._seen_dg_stack: list[set[str]] = []
self._most_recent_messages: list[MsgData] = []
self._media_data: list[MediaMsgData] = []
self._cache_type = cache_type
def __repr__(self) -> str:
return util.repr_(self)
@contextlib.contextmanager
def calling_cached_function(self, func: FunctionType) -> Iterator[None]:
"""Context manager that should wrap the invocation of a cached function.
It allows us to track any `st.foo` messages that are generated from inside the
function for playback during cache retrieval.
"""
self._cached_message_stack.append([])
self._seen_dg_stack.append(set())
nested_call = False
if in_cached_function.get():
nested_call = True
# If we're in a cached function. To disallow usage of widget-like element,
# we need to set the in_cached_function to true for this cached function run
# to prevent widget usage (triggers a warning).
in_cached_function.set(True)
try:
yield
finally:
self._most_recent_messages = self._cached_message_stack.pop()
self._seen_dg_stack.pop()
if not nested_call:
# Reset the in_cached_function flag. But only if this
# is not nested inside a cached function that disallows widget usage.
in_cached_function.set(False)
def save_element_message(
self,
delta_type: str,
element_proto: Message,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
"""Record the element protobuf as having been produced during any currently
executing cached functions, so they can be replayed any time the function's
execution is skipped because they're in the cache.
"""
if not runtime.exists():
return
if len(self._cached_message_stack) >= 1:
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
media_data = self._media_data
element_msg_data = ElementMsgData(
delta_type,
element_proto,
id_to_save,
returned_dg_id,
media_data,
)
for msgs in self._cached_message_stack:
msgs.append(element_msg_data)
# Reset instance state, now that it has been used for the
# associated element.
self._media_data = []
for s in self._seen_dg_stack:
s.add(returned_dg_id)
def save_block_message(
self,
block_proto: Block,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
for msgs in self._cached_message_stack:
msgs.append(BlockMsgData(block_proto, id_to_save, returned_dg_id))
for s in self._seen_dg_stack:
s.add(returned_dg_id)
def select_dg_to_save(self, invoked_id: str, acting_on_id: str) -> str:
"""Select the id of the DG that this message should be invoked on
during message replay.
See Note [DeltaGenerator method invocation]
invoked_id is the DG the st function was called on, usually `st._main`.
acting_on_id is the DG the st function ultimately runs on, which may be different
if the invoked DG delegated to another one because it was in a `with` block.
"""
if len(self._seen_dg_stack) > 0 and acting_on_id in self._seen_dg_stack[-1]:
return acting_on_id
else:
return invoked_id
def save_image_data(
self, image_data: bytes | str, mimetype: str, image_id: str
) -> None:
self._media_data.append(MediaMsgData(image_data, mimetype, image_id))
def replay_cached_messages(
result: CachedResult, cache_type: CacheType, cached_func: FunctionType
) -> None:
"""Replay the st element function calls that happened when executing a
cache-decorated function.
When a cache function is executed, we record the element and block messages
produced, and use those to reproduce the DeltaGenerator calls, so the elements
will appear in the web app even when execution of the function is skipped
because the result was cached.
To make this work, for each st function call we record an identifier for the
DG it was effectively called on (see Note [DeltaGenerator method invocation]).
We also record the identifier for each DG returned by an st function call, if
it returns one. Then, for each recorded message, we get the current DG instance
corresponding to the DG the message was originally called on, and enqueue the
message using that, recording any new DGs produced in case a later st function
call is on one of them.
"""
from streamlit.delta_generator import DeltaGenerator
# Maps originally recorded dg ids to this script run's version of that dg
returned_dgs: dict[str, DeltaGenerator] = {
result.main_id: st._main,
result.sidebar_id: st.sidebar,
}
try:
for msg in result.messages:
if isinstance(msg, ElementMsgData):
if msg.media_data is not None:
for data in msg.media_data:
runtime.get_instance().media_file_mgr.add(
data.media, data.mimetype, data.media_id
)
dg = returned_dgs[msg.id_of_dg_called_on]
maybe_dg = dg._enqueue(msg.delta_type, msg.message)
if isinstance(maybe_dg, DeltaGenerator):
returned_dgs[msg.returned_dgs_id] = maybe_dg
elif isinstance(msg, BlockMsgData):
dg = returned_dgs[msg.id_of_dg_called_on]
new_dg = dg._block(msg.message)
returned_dgs[msg.returned_dgs_id] = new_dg
except KeyError as ex:
raise CacheReplayClosureError(cache_type, cached_func) from ex
def show_widget_replay_deprecation(
decorator: Literal["cache_data", "cache_resource"],
) -> None:
show_deprecation_warning(
"The cached widget replay feature was removed in 1.38. The "
"`experimental_allow_widgets` parameter will also be removed "
"in a future release. Please remove the `experimental_allow_widgets` parameter "
f"from the `@st.{decorator}` decorator and move all widget commands outside of "
"cached functions.\n\nTo speed up your app, we recommend moving your widgets "
"into fragments. Find out more about fragments in "
"[our docs](https://docs.streamlit.io/develop/api-reference/execution-flow/st.fragment). "
"\n\nIf you have a specific use-case that requires the "
"`experimental_allow_widgets` functionality, please tell us via an "
"[issue on Github](https://github.com/streamlit/streamlit/issues)."
)

View File

@@ -0,0 +1,637 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hashing for st.cache_data and st.cache_resource."""
from __future__ import annotations
import collections
import collections.abc
import dataclasses
import datetime
import functools
import hashlib
import inspect
import io
import os
import pickle
import sys
import tempfile
import threading
import uuid
import weakref
from enum import Enum
from re import Pattern
from types import MappingProxyType
from typing import Any, Callable, Final, Union, cast
from typing_extensions import TypeAlias
from streamlit import logger, type_util, util
from streamlit.errors import StreamlitAPIException
from streamlit.runtime.caching.cache_errors import UnhashableTypeError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.uploaded_file_manager import UploadedFile
_LOGGER: Final = logger.get_logger(__name__)
# If a dataframe has more than this many rows, we consider it large and hash a sample.
_PANDAS_ROWS_LARGE: Final = 50_000
_PANDAS_SAMPLE_SIZE: Final = 10_000
# Similar to dataframes, we also sample large numpy arrays.
_NP_SIZE_LARGE: Final = 500_000
_NP_SAMPLE_SIZE: Final = 100_000
HashFuncsDict: TypeAlias = dict[Union[str, type[Any]], Callable[[Any], Any]]
# Arbitrary item to denote where we found a cycle in a hashed object.
# This allows us to hash self-referencing lists, dictionaries, etc.
_CYCLE_PLACEHOLDER: Final = (
b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE"
)
class UserHashError(StreamlitAPIException):
def __init__(
self,
orig_exc,
object_to_hash,
hash_func,
cache_type: CacheType | None = None,
):
self.alternate_name = type(orig_exc).__name__
self.hash_func = hash_func
self.cache_type = cache_type
msg = self._get_message_from_func(orig_exc, object_to_hash)
super().__init__(msg)
self.with_traceback(orig_exc.__traceback__)
def _get_message_from_func(self, orig_exc, cached_func):
args = self._get_error_message_args(orig_exc, cached_func)
return (
"""
%(orig_exception_desc)s
This error is likely due to a bug in %(hash_func_name)s, which is a
user-defined hash function that was passed into the `%(cache_primitive)s` decorator of
%(object_desc)s.
%(hash_func_name)s failed when hashing an object of type
`%(failed_obj_type_str)s`. If you don't know where that object is coming from,
try looking at the hash chain below for an object that you do recognize, then
pass that to `hash_funcs` instead:
```
%(hash_stack)s
```
If you think this is actually a Streamlit bug, please
[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose).
"""
% args
).strip("\n")
def _get_error_message_args(
self,
orig_exc: BaseException,
failed_obj: Any,
) -> dict[str, Any]:
hash_source = hash_stacks.current.hash_source
failed_obj_type_str = type_util.get_fqn_type(failed_obj)
if hash_source is None:
object_desc = "something"
else:
if hasattr(hash_source, "__name__"):
object_desc = f"`{hash_source.__name__}()`"
else:
object_desc = "a function"
decorator_name = ""
if self.cache_type is CacheType.RESOURCE:
decorator_name = "@st.cache_resource"
elif self.cache_type is CacheType.DATA:
decorator_name = "@st.cache_data"
if hasattr(self.hash_func, "__name__"):
hash_func_name = f"`{self.hash_func.__name__}()`"
else:
hash_func_name = "a function"
return {
"orig_exception_desc": str(orig_exc),
"failed_obj_type_str": failed_obj_type_str,
"hash_stack": hash_stacks.current.pretty_print(),
"object_desc": object_desc,
"cache_primitive": decorator_name,
"hash_func_name": hash_func_name,
}
def update_hash(
val: Any,
hasher,
cache_type: CacheType,
hash_source: Callable[..., Any] | None = None,
hash_funcs: HashFuncsDict | None = None,
) -> None:
"""Updates a hashlib hasher with the hash of val.
This is the main entrypoint to hashing.py.
"""
hash_stacks.current.hash_source = hash_source
ch = _CacheFuncHasher(cache_type, hash_funcs)
ch.update(hasher, val)
class _HashStack:
"""Stack of what has been hashed, for debug and circular reference detection.
This internally keeps 1 stack per thread.
Internally, this stores the ID of pushed objects rather than the objects
themselves because otherwise the "in" operator inside __contains__ would
fail for objects that don't return a boolean for "==" operator. For
example, arr == 10 where arr is a NumPy array returns another NumPy array.
This causes the "in" to crash since it expects a boolean.
"""
def __init__(self):
self._stack: collections.OrderedDict[int, list[Any]] = collections.OrderedDict()
# A function that we decorate with streamlit cache
# primitive (st.cache_data or st.cache_resource).
self.hash_source: Callable[..., Any] | None = None
def __repr__(self) -> str:
return util.repr_(self)
def push(self, val: Any):
self._stack[id(val)] = val
def pop(self):
self._stack.popitem()
def __contains__(self, val: Any):
return id(val) in self._stack
def pretty_print(self) -> str:
def to_str(v: Any) -> str:
try:
return f"Object of type {type_util.get_fqn_type(v)}: {str(v)}"
except Exception:
return "<Unable to convert item to string>"
return "\n".join(to_str(x) for x in reversed(self._stack.values()))
class _HashStacks:
"""Stacks of what has been hashed, with at most 1 stack per thread."""
def __init__(self):
self._stacks: weakref.WeakKeyDictionary[threading.Thread, _HashStack] = (
weakref.WeakKeyDictionary()
)
def __repr__(self) -> str:
return util.repr_(self)
@property
def current(self) -> _HashStack:
current_thread = threading.current_thread()
stack = self._stacks.get(current_thread, None)
if stack is None:
stack = _HashStack()
self._stacks[current_thread] = stack
return stack
hash_stacks = _HashStacks()
def _int_to_bytes(i: int) -> bytes:
num_bytes = (i.bit_length() + 8) // 8
return i.to_bytes(num_bytes, "little", signed=True)
def _float_to_bytes(f: float) -> bytes:
# Lazy-load for performance reasons.
import struct
# Floats are 64bit in Python, so we need to use the "d" format.
return struct.pack("<d", f)
def _key(obj: Any | None) -> Any:
"""Return key for memoization."""
if obj is None:
return None
def is_simple(obj):
return (
isinstance(obj, bytes)
or isinstance(obj, bytearray)
or isinstance(obj, str)
or isinstance(obj, float)
or isinstance(obj, int)
or isinstance(obj, bool)
or isinstance(obj, uuid.UUID)
or obj is None
)
if is_simple(obj):
return obj
if isinstance(obj, tuple):
if all(map(is_simple, obj)):
return obj
if isinstance(obj, list):
if all(map(is_simple, obj)):
return ("__l", tuple(obj))
if inspect.isbuiltin(obj) or inspect.isroutine(obj) or inspect.iscode(obj):
return id(obj)
return NoResult
class _CacheFuncHasher:
"""A hasher that can hash objects with cycles."""
def __init__(self, cache_type: CacheType, hash_funcs: HashFuncsDict | None = None):
# Can't use types as the keys in the internal _hash_funcs because
# we always remove user-written modules from memory when rerunning a
# script in order to reload it and grab the latest code changes.
# (See LocalSourcesWatcher.py:on_file_changed) This causes
# the type object to refer to different underlying class instances each run,
# so type-based comparisons fail. To solve this, we use the types converted
# to fully-qualified strings as keys in our internal dict.
self._hash_funcs: HashFuncsDict
if hash_funcs:
self._hash_funcs = {
k if isinstance(k, str) else type_util.get_fqn(k): v
for k, v in hash_funcs.items()
}
else:
self._hash_funcs = {}
self._hashes: dict[Any, bytes] = {}
# The number of the bytes in the hash.
self.size = 0
self.cache_type = cache_type
def __repr__(self) -> str:
return util.repr_(self)
def to_bytes(self, obj: Any) -> bytes:
"""Add memoization to _to_bytes and protect against cycles in data structures."""
tname = type(obj).__qualname__.encode()
key = (tname, _key(obj))
# Memoize if possible.
if key[1] is not NoResult:
if key in self._hashes:
return self._hashes[key]
# Break recursive cycles.
if obj in hash_stacks.current:
return _CYCLE_PLACEHOLDER
hash_stacks.current.push(obj)
try:
# Hash the input
b = b"%s:%s" % (tname, self._to_bytes(obj))
# Hmmm... It's possible that the size calculation is wrong. When we
# call to_bytes inside _to_bytes things get double-counted.
self.size += sys.getsizeof(b)
if key[1] is not NoResult:
self._hashes[key] = b
finally:
# In case an UnhashableTypeError (or other) error is thrown, clean up the
# stack so we don't get false positives in future hashing calls
hash_stacks.current.pop()
return b
def update(self, hasher, obj: Any) -> None:
"""Update the provided hasher with the hash of an object."""
b = self.to_bytes(obj)
hasher.update(b)
def _to_bytes(self, obj: Any) -> bytes:
"""Hash objects to bytes, including code with dependencies.
Python's built in `hash` does not produce consistent results across
runs.
"""
h = hashlib.new("md5", usedforsecurity=False)
if type_util.is_type(obj, "unittest.mock.Mock") or type_util.is_type(
obj, "unittest.mock.MagicMock"
):
# Mock objects can appear to be infinitely
# deep, so we don't try to hash them at all.
return self.to_bytes(id(obj))
elif isinstance(obj, bytes) or isinstance(obj, bytearray):
return obj
elif type_util.get_fqn_type(obj) in self._hash_funcs:
# Escape hatch for unsupported objects
hash_func = self._hash_funcs[type_util.get_fqn_type(obj)]
try:
output = hash_func(obj)
except Exception as ex:
raise UserHashError(
ex, obj, hash_func=hash_func, cache_type=self.cache_type
) from ex
return self.to_bytes(output)
elif isinstance(obj, str):
return obj.encode()
elif isinstance(obj, float):
return _float_to_bytes(obj)
elif isinstance(obj, int):
return _int_to_bytes(obj)
elif isinstance(obj, uuid.UUID):
return obj.bytes
elif isinstance(obj, datetime.datetime):
return obj.isoformat().encode()
elif isinstance(obj, (list, tuple)):
for item in obj:
self.update(h, item)
return h.digest()
elif isinstance(obj, dict):
for item in obj.items():
self.update(h, item)
return h.digest()
elif obj is None:
return b"0"
elif obj is True:
return b"1"
elif obj is False:
return b"0"
elif not isinstance(obj, type) and dataclasses.is_dataclass(obj):
return self.to_bytes(dataclasses.asdict(obj))
elif isinstance(obj, Enum):
return str(obj).encode()
elif type_util.is_type(obj, "pandas.core.series.Series"):
import pandas as pd
obj = cast("pd.Series", obj)
self.update(h, obj.size)
self.update(h, obj.dtype.name)
if len(obj) >= _PANDAS_ROWS_LARGE:
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
try:
self.update(h, pd.util.hash_pandas_object(obj).to_numpy().tobytes())
return h.digest()
except TypeError:
_LOGGER.warning(
"Pandas Series hash failed. Falling back to pickling the object.",
exc_info=True,
)
# Use pickle if pandas cannot hash the object for example if
# it contains unhashable objects.
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
elif type_util.is_type(obj, "pandas.core.frame.DataFrame"):
import pandas as pd
obj = cast("pd.DataFrame", obj)
self.update(h, obj.shape)
if len(obj) >= _PANDAS_ROWS_LARGE:
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
try:
column_hash_bytes = self.to_bytes(
pd.util.hash_pandas_object(obj.dtypes)
)
self.update(h, column_hash_bytes)
values_hash_bytes = self.to_bytes(pd.util.hash_pandas_object(obj))
self.update(h, values_hash_bytes)
return h.digest()
except TypeError:
_LOGGER.warning(
"Pandas DataFrame hash failed. Falling back to pickling the object.",
exc_info=True,
)
# Use pickle if pandas cannot hash the object for example if
# it contains unhashable objects.
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
elif type_util.is_type(obj, "polars.series.series.Series"):
import polars as pl # type: ignore[import-not-found]
obj = cast("pl.Series", obj)
self.update(h, str(obj.dtype).encode())
self.update(h, obj.shape)
if len(obj) >= _PANDAS_ROWS_LARGE:
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, seed=0)
try:
self.update(h, obj.hash(seed=0).to_arrow().to_string().encode())
return h.digest()
except TypeError:
_LOGGER.warning(
"Polars Series hash failed. Falling back to pickling the object.",
exc_info=True,
)
# Use pickle if polars cannot hash the object for example if
# it contains unhashable objects.
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
elif type_util.is_type(obj, "polars.dataframe.frame.DataFrame"):
import polars as pl # noqa: TC002
obj = cast("pl.DataFrame", obj)
self.update(h, obj.shape)
if len(obj) >= _PANDAS_ROWS_LARGE:
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, seed=0)
try:
for c, t in obj.schema.items():
self.update(h, c.encode())
self.update(h, str(t).encode())
values_hash_bytes = (
obj.hash_rows(seed=0).hash(seed=0).to_arrow().to_string().encode()
)
self.update(h, values_hash_bytes)
return h.digest()
except TypeError:
_LOGGER.warning(
"Polars DataFrame hash failed. Falling back to pickling the object.",
exc_info=True,
)
# Use pickle if polars cannot hash the object for example if
# it contains unhashable objects.
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
elif type_util.is_type(obj, "numpy.ndarray"):
import numpy as np
# write cast type as string to make it work with our Python 3.8 tests
# - can be removed once we sunset support for Python 3.8
obj = cast("np.ndarray[Any, Any]", obj)
self.update(h, obj.shape)
self.update(h, str(obj.dtype))
if obj.size >= _NP_SIZE_LARGE:
import numpy as np
state = np.random.RandomState(0)
obj = state.choice(obj.flat, size=_NP_SAMPLE_SIZE)
self.update(h, obj.tobytes())
return h.digest()
elif type_util.is_type(obj, "PIL.Image.Image"):
import numpy as np
from PIL.Image import Image # noqa: TC002
obj = cast("Image", obj)
# we don't just hash the results of obj.tobytes() because we want to use
# the sampling logic for numpy data
np_array = np.frombuffer(obj.tobytes(), dtype="uint8")
return self.to_bytes(np_array)
elif inspect.isbuiltin(obj):
return bytes(obj.__name__.encode())
elif isinstance(obj, MappingProxyType) or isinstance(
obj, collections.abc.ItemsView
):
return self.to_bytes(dict(obj))
elif type_util.is_type(obj, "builtins.getset_descriptor"):
return bytes(obj.__qualname__.encode())
elif isinstance(obj, UploadedFile):
# UploadedFile is a BytesIO (thus IOBase) but has a name.
# It does not have a timestamp so this must come before
# temporary files
self.update(h, obj.name)
self.update(h, obj.tell())
self.update(h, obj.getvalue())
return h.digest()
elif hasattr(obj, "name") and (
isinstance(obj, io.IOBase)
# Handle temporary files used during testing
or isinstance(obj, tempfile._TemporaryFileWrapper)
):
# Hash files as name + last modification date + offset.
# NB: we're using hasattr("name") to differentiate between
# on-disk and in-memory StringIO/BytesIO file representations.
# That means that this condition must come *before* the next
# condition, which just checks for StringIO/BytesIO.
obj_name = getattr(obj, "name", "wonthappen") # Just to appease MyPy.
self.update(h, obj_name)
self.update(h, os.path.getmtime(obj_name))
self.update(h, obj.tell())
return h.digest()
elif isinstance(obj, Pattern):
return self.to_bytes([obj.pattern, obj.flags])
elif isinstance(obj, io.StringIO) or isinstance(obj, io.BytesIO):
# Hash in-memory StringIO/BytesIO by their full contents
# and seek position.
self.update(h, obj.tell())
self.update(h, obj.getvalue())
return h.digest()
elif type_util.is_type(obj, "numpy.ufunc"):
# For numpy.remainder, this returns remainder.
return bytes(obj.__name__.encode())
elif inspect.ismodule(obj):
# TODO: Figure out how to best show this kind of warning to the
# user. In the meantime, show nothing. This scenario is too common,
# so the current warning is quite annoying...
# st.warning(('Streamlit does not support hashing modules. '
# 'We did not hash `%s`.') % obj.__name__)
# TODO: Hash more than just the name for internal modules.
return self.to_bytes(obj.__name__)
elif inspect.isclass(obj):
# TODO: Figure out how to best show this kind of warning to the
# user. In the meantime, show nothing. This scenario is too common,
# (e.g. in every "except" statement) so the current warning is
# quite annoying...
# st.warning(('Streamlit does not support hashing classes. '
# 'We did not hash `%s`.') % obj.__name__)
# TODO: Hash more than just the name of classes.
return self.to_bytes(obj.__name__)
elif isinstance(obj, functools.partial):
# The return value of functools.partial is not a plain function:
# it's a callable object that remembers the original function plus
# the values you pickled into it. So here we need to special-case it.
self.update(h, obj.args)
self.update(h, obj.func)
self.update(h, obj.keywords)
return h.digest()
else:
# As a last resort, hash the output of the object's __reduce__ method
try:
reduce_data = obj.__reduce__()
except Exception as ex:
raise UnhashableTypeError() from ex
for item in reduce_data:
self.update(h, item)
return h.digest()
class NoResult:
"""Placeholder class for return values when None is meaningful."""
pass

View File

@@ -0,0 +1,169 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library of caching utilities."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from streamlit import deprecation_util
from streamlit.runtime.caching import CACHE_DOCS_URL
from streamlit.runtime.metrics_util import gather_metrics
if TYPE_CHECKING:
from streamlit.runtime.caching.hashing import HashFuncsDict
# Type-annotate the decorator function.
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
F = TypeVar("F", bound=Callable[..., Any])
@gather_metrics("cache")
def cache(
func: F | None = None,
persist: bool = False,
allow_output_mutation: bool = False,
show_spinner: bool = True,
suppress_st_warning: bool = False,
hash_funcs: HashFuncsDict | None = None,
max_entries: int | None = None,
ttl: float | None = None,
):
"""Legacy caching decorator (deprecated).
Legacy caching with ``st.cache`` has been removed from Streamlit. This is
now an alias for ``st.cache_data`` and ``st.cache_resource``.
Parameters
----------
func : callable
The function to cache. Streamlit hashes the function's source code.
persist : bool
Whether to persist the cache on disk.
allow_output_mutation : bool
Whether to use ``st.cache_data`` or ``st.cache_resource``. If this is
``False`` (default), the arguments are passed to ``st.cache_data``. If
this is ``True``, the arguments are passed to ``st.cache_resource``.
show_spinner : bool
Enable the spinner. Default is ``True`` to show a spinner when there is
a "cache miss" and the cached data is being created.
suppress_st_warning : bool
This is not used.
hash_funcs : dict or None
Mapping of types or fully qualified names to hash functions. This is used to
override the behavior of the hasher inside Streamlit's caching mechanism: when
the hasher encounters an object, it will first check to see if its type matches
a key in this dict and, if so, will use the provided function to generate a hash
for it. See below for an example of how this can be used.
max_entries : int or None
The maximum number of entries to keep in the cache, or ``None``
for an unbounded cache. (When a new entry is added to a full cache,
the oldest cached entry will be removed.) The default is ``None``.
ttl : float or None
The maximum number of seconds to keep an entry in the cache, or
None if cache entries should not expire. The default is None.
Example
-------
>>> import streamlit as st
>>>
>>> @st.cache
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
>>>
>>> d1 = fetch_and_clean_data(DATA_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> d2 = fetch_and_clean_data(DATA_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value. This means that now the data in d1 is the same as in d2.
>>>
>>> d3 = fetch_and_clean_data(DATA_URL_2)
>>> # This is a different URL, so the function executes.
To set the ``persist`` parameter, use this command as follows:
>>> @st.cache(persist=True)
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
To disable hashing return values, set the ``allow_output_mutation`` parameter to
``True``:
>>> @st.cache(allow_output_mutation=True)
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
To override the default hashing behavior, pass a custom hash function.
You can do that by mapping a type (e.g. ``MongoClient``) to a hash function (``id``)
like this:
>>> @st.cache(hash_funcs={MongoClient: id})
... def connect_to_database(url):
... return MongoClient(url)
Alternatively, you can map the type's fully-qualified name
(e.g. ``"pymongo.mongo_client.MongoClient"``) to the hash function instead:
>>> @st.cache(hash_funcs={"pymongo.mongo_client.MongoClient": id})
... def connect_to_database(url):
... return MongoClient(url)
"""
import streamlit as st
deprecation_util.show_deprecation_warning(
f"""
`st.cache` is deprecated and will be removed soon. Please use one of Streamlit's new
caching commands, `st.cache_data` or `st.cache_resource`. More information
[in our docs]({CACHE_DOCS_URL}).
**Note**: The behavior of `st.cache` was updated in Streamlit 1.36 to the new caching
logic used by `st.cache_data` and `st.cache_resource`. This might lead to some problems
or unexpected behavior in certain edge cases.
"""
)
# suppress_st_warning is unused since its not supported by the new caching commands
if allow_output_mutation:
return st.cache_resource( # type: ignore
func,
show_spinner=show_spinner,
hash_funcs=hash_funcs,
max_entries=max_entries,
ttl=ttl,
)
return st.cache_data( # type: ignore
func,
persist=persist,
show_spinner=show_spinner,
hash_funcs=hash_funcs,
max_entries=max_entries,
ttl=ttl,
)

View File

@@ -0,0 +1,29 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageError,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
__all__ = [
"CacheStorage",
"CacheStorageContext",
"CacheStorageError",
"CacheStorageKeyNotFoundError",
"CacheStorageManager",
]

View File

@@ -0,0 +1,239 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Declares the CacheStorageContext dataclass, which contains parameter information for
each function decorated by `@st.cache_data` (for example: ttl, max_entries etc.).
Declares the CacheStorageManager protocol, which implementations are used
to create CacheStorage instances and to optionally clear all cache storages,
that were created by this manager, and to check if the context is valid for the storage.
Declares the CacheStorage protocol, which implementations are used to store cached
values for a single `@st.cache_data` decorated function serialized as bytes.
How these classes work together
-------------------------------
- CacheStorageContext : this is a dataclass that contains the parameters from
`@st.cache_data` that are passed to the CacheStorageManager.create() method.
- CacheStorageManager : each instance of this is able to create CacheStorage
instances, and optionally to clear data of all cache storages.
- CacheStorage : each instance of this is able to get, set, delete, and clear
entries for a single `@st.cache_data` decorated function.
┌───────────────────────────────┐
│ │
│ CacheStorageManager │
│ │
│ - clear_all(optional) │
│ - check_context │
│ │
└──┬────────────────────────────┘
│ ┌──────────────────────┐
│ │ CacheStorage │
│ create(context)│ │
└────────────────► - get │
│ - set │
│ - delete │
│ - close (optional)│
│ - clear │
└──────────────────────┘
"""
from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass
from typing import Literal, Protocol
class CacheStorageError(Exception):
"""Base exception raised by the cache storage."""
class CacheStorageKeyNotFoundError(CacheStorageError):
"""Raised when the key is not found in the cache storage."""
class InvalidCacheStorageContext(CacheStorageError):
"""Raised if the cache storage manager is not able to work with
provided CacheStorageContext.
"""
@dataclass(frozen=True)
class CacheStorageContext:
"""Context passed to the cache storage during initialization
This is the normalized parameters that are passed to CacheStorageManager.create()
method.
Parameters
----------
function_key: str
A hash computed based on function name and source code decorated
by `@st.cache_data`
function_display_name: str
The display name of the function that is decorated by `@st.cache_data`
ttl_seconds : float or None
The time-to-live for the keys in storage, in seconds. If None, the entry
will never expire.
max_entries : int or None
The maximum number of entries to store in the cache storage.
If None, the cache storage will not limit the number of entries.
persist : Literal["disk"] or None
The persistence mode for the cache storage.
Legacy parameter, that used in Streamlit current cache storage implementation.
Could be ignored by cache storage implementation, if storage does not support
persistence or it persistent by default.
"""
function_key: str
function_display_name: str
ttl_seconds: float | None = None
max_entries: int | None = None
persist: Literal["disk"] | None = None
class CacheStorage(Protocol):
"""Cache storage protocol, that should be implemented by the concrete cache storages.
Used to store cached values for a single `@st.cache_data` decorated function
serialized as bytes.
CacheStorage instances should be created by `CacheStorageManager.create()` method.
Notes
-----
Threading: The methods of this protocol could be called from multiple threads.
This is a responsibility of the concrete implementation to ensure thread safety
guarantees.
"""
@abstractmethod
def get(self, key: str) -> bytes:
"""Returns the stored value for the key.
Raises
------
CacheStorageKeyNotFoundError
Raised if the key is not in the storage.
"""
raise NotImplementedError
@abstractmethod
def set(self, key: str, value: bytes) -> None:
"""Sets the value for a given key."""
raise NotImplementedError
@abstractmethod
def delete(self, key: str) -> None:
"""Delete a given key."""
raise NotImplementedError
@abstractmethod
def clear(self) -> None:
"""Remove all keys for the storage."""
raise NotImplementedError
def close(self) -> None:
"""Closes the cache storage, it is optional to implement, and should be used
to close open resources, before we delete the storage instance.
e.g. close the database connection etc.
"""
pass
class CacheStorageManager(Protocol):
"""Cache storage manager protocol, that should be implemented by the concrete
cache storage managers.
It is responsible for:
- Creating cache storage instances for the specific
decorated functions,
- Validating the context for the cache storages.
- Optionally clearing all cache storages in optimal way.
It should be created during Runtime initialization.
"""
@abstractmethod
def create(self, context: CacheStorageContext) -> CacheStorage:
"""Creates a new cache storage instance
Please note that the ttl, max_entries and other context fields are specific
for whole storage, not for individual key.
Notes
-----
Threading: Should be safe to call from any thread.
"""
raise NotImplementedError
def clear_all(self) -> None:
"""Remove everything what possible from the cache storages in optimal way.
meaningful default behaviour is to raise NotImplementedError, so this is not
abstractmethod.
The method is optional to implement: cache data API will fall back to remove
all available storages one by one via storage.clear() method
if clear_all raises NotImplementedError.
Raises
------
NotImplementedError
Raised if the storage manager does not provide an ability to clear
all storages at once in optimal way.
Notes
-----
Threading: This method could be called from multiple threads.
This is a responsibility of the concrete implementation to ensure
thread safety guarantees.
"""
raise NotImplementedError
def check_context(self, context: CacheStorageContext) -> None:
"""Checks if the context is valid for the storage manager.
This method should not return anything, but log message or raise an exception
if the context is invalid.
In case of raising an exception, we not handle it and let the exception to be
propagated.
check_context is called only once at the moment of creating `@st.cache_data`
decorator for specific function, so it is not called for every cache hit.
Parameters
----------
context: CacheStorageContext
The context to check for the storage manager, dummy function_key in context
will be used, since it is not computed at the point of calling this method.
Raises
------
InvalidCacheStorageContext
Raised if the cache storage manager is not able to work with provided
CacheStorageContext. When possible we should log message instead, since
this exception will be propagated to the user.
Notes
-----
Threading: Should be safe to call from any thread.
"""
pass

View File

@@ -0,0 +1,60 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
InMemoryCacheStorageWrapper,
)
class MemoryCacheStorageManager(CacheStorageManager):
def create(self, context: CacheStorageContext) -> CacheStorage:
"""Creates a new cache storage instance wrapped with in-memory cache layer."""
persist_storage = DummyCacheStorage()
return InMemoryCacheStorageWrapper(
persist_storage=persist_storage, context=context
)
def clear_all(self) -> None:
raise NotImplementedError
def check_context(self, context: CacheStorageContext) -> None:
pass
class DummyCacheStorage(CacheStorage):
def get(self, key: str) -> bytes:
"""
Dummy gets the value for a given key,
always raises an CacheStorageKeyNotFoundError.
"""
raise CacheStorageKeyNotFoundError("Key not found in dummy cache")
def set(self, key: str, value: bytes) -> None:
pass
def delete(self, key: str) -> None:
pass
def clear(self) -> None:
pass
def close(self) -> None:
pass

View File

@@ -0,0 +1,145 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import threading
from cachetools import TTLCache
from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageKeyNotFoundError,
)
from streamlit.runtime.stats import CacheStat
_LOGGER = get_logger(__name__)
class InMemoryCacheStorageWrapper(CacheStorage):
"""
In-memory cache storage wrapper.
This class wraps a cache storage and adds an in-memory cache front layer,
which is used to reduce the number of calls to the storage.
The in-memory cache is a TTL cache, which means that the entries are
automatically removed if a given time to live (TTL) has passed.
The in-memory cache is also an LRU cache, which means that the entries
are automatically removed if the cache size exceeds a given maxsize.
If the storage implements its strategy for maxsize, it is recommended
(but not necessary) that the storage implement the same LRU strategy,
otherwise a situation may arise when different items are deleted from
the memory cache and from the storage.
Notes
-----
Threading: in-memory caching layer is thread safe: we hold self._mem_cache_lock for
working with this self._mem_cache object.
However, we do not hold this lock when calling into the underlying storage,
so it is the responsibility of the that storage to ensure that it is safe to use
it from multiple threads.
"""
def __init__(self, persist_storage: CacheStorage, context: CacheStorageContext):
self.function_key = context.function_key
self.function_display_name = context.function_display_name
self._ttl_seconds = context.ttl_seconds
self._max_entries = context.max_entries
self._mem_cache: TTLCache[str, bytes] = TTLCache(
maxsize=self.max_entries,
ttl=self.ttl_seconds,
timer=cache_utils.TTLCACHE_TIMER,
)
self._mem_cache_lock = threading.Lock()
self._persist_storage = persist_storage
@property
def ttl_seconds(self) -> float:
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
@property
def max_entries(self) -> float:
return float(self._max_entries) if self._max_entries is not None else math.inf
def get(self, key: str) -> bytes:
"""
Returns the stored value for the key or raise CacheStorageKeyNotFoundError if
the key is not found.
"""
try:
entry_bytes = self._read_from_mem_cache(key)
except CacheStorageKeyNotFoundError:
entry_bytes = self._persist_storage.get(key)
self._write_to_mem_cache(key, entry_bytes)
return entry_bytes
def set(self, key: str, value: bytes) -> None:
"""Sets the value for a given key."""
self._write_to_mem_cache(key, value)
self._persist_storage.set(key, value)
def delete(self, key: str) -> None:
"""Delete a given key."""
self._remove_from_mem_cache(key)
self._persist_storage.delete(key)
def clear(self) -> None:
"""Delete all keys for the in memory cache, and also the persistent storage."""
with self._mem_cache_lock:
self._mem_cache.clear()
self._persist_storage.clear()
def get_stats(self) -> list[CacheStat]:
"""Returns a list of stats in bytes for the cache memory storage per item."""
stats = []
with self._mem_cache_lock:
for item in self._mem_cache.values():
stats.append(
CacheStat(
category_name="st_cache_data",
cache_name=self.function_display_name,
byte_length=len(item),
)
)
return stats
def close(self) -> None:
"""Closes the cache storage."""
self._persist_storage.close()
def _read_from_mem_cache(self, key: str) -> bytes:
with self._mem_cache_lock:
if key in self._mem_cache:
entry = bytes(self._mem_cache[key])
_LOGGER.debug("Memory cache HIT: %s", key)
return entry
else:
_LOGGER.debug("Memory cache MISS: %s", key)
raise CacheStorageKeyNotFoundError("Key not found in mem cache")
def _write_to_mem_cache(self, key: str, entry_bytes: bytes) -> None:
with self._mem_cache_lock:
self._mem_cache[key] = entry_bytes
def _remove_from_mem_cache(self, key: str) -> None:
with self._mem_cache_lock:
self._mem_cache.pop(key, None)

View File

@@ -0,0 +1,223 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Declares the LocalDiskCacheStorageManager class, which is used
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper,
InMemoryCacheStorageWrapper wrapper allows to have first layer of in-memory cache,
before accessing to LocalDiskCacheStorage itself.
Declares the LocalDiskCacheStorage class, which is used to store cached
values on disk.
How these classes work together
-------------------------------
- LocalDiskCacheStorageManager : each instance of this is able
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper,
and to clear data from cache storage folder. It is also LocalDiskCacheStorageManager
responsibility to check if the context is valid for the storage, and to log warning
if the context is not valid.
- LocalDiskCacheStorage : each instance of this is able to get, set, delete, and clear
entries from disk for a single `@st.cache_data` decorated function if `persist="disk"`
is used in CacheStorageContext.
┌───────────────────────────────┐
│ LocalDiskCacheStorageManager │
│ │
│ - clear_all │
│ - check_context │
│ │
└──┬────────────────────────────┘
│ ┌──────────────────────────────┐
│ │ │
│ create(context)│ InMemoryCacheStorageWrapper │
└────────────────► │
│ ┌─────────────────────┐ │
│ │ │ │
│ │ LocalDiskStorage │ │
│ │ │ │
│ └─────────────────────┘ │
│ │
└──────────────────────────────┘
"""
from __future__ import annotations
import math
import os
import shutil
from typing import Final
from streamlit import errors
from streamlit.file_util import get_streamlit_file_path, streamlit_read, streamlit_write
from streamlit.logger import get_logger
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageError,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
InMemoryCacheStorageWrapper,
)
_LOGGER: Final = get_logger(__name__)
# Streamlit directory where persisted @st.cache_data objects live.
# (This is the same directory that @st.cache persisted objects live.
# But @st.cache_data uses a different extension, so they don't overlap.)
_CACHE_DIR_NAME: Final = "cache"
# The extension for our persisted @st.cache_data objects.
# (`@st.cache_data` was originally called `@st.memo`)
_CACHED_FILE_EXTENSION: Final = "memo"
class LocalDiskCacheStorageManager(CacheStorageManager):
def create(self, context: CacheStorageContext) -> CacheStorage:
"""Creates a new cache storage instance wrapped with in-memory cache layer."""
persist_storage = LocalDiskCacheStorage(context)
return InMemoryCacheStorageWrapper(
persist_storage=persist_storage, context=context
)
def clear_all(self) -> None:
cache_path = get_cache_folder_path()
if os.path.isdir(cache_path):
shutil.rmtree(cache_path)
def check_context(self, context: CacheStorageContext) -> None:
if (
context.persist == "disk"
and context.ttl_seconds is not None
and not math.isinf(context.ttl_seconds)
):
_LOGGER.warning(
f"The cached function '{context.function_display_name}' has a TTL "
"that will be ignored. Persistent cached functions currently don't "
"support TTL."
)
class LocalDiskCacheStorage(CacheStorage):
"""Cache storage that persists data to disk
This is the default cache persistence layer for `@st.cache_data`.
"""
def __init__(self, context: CacheStorageContext):
self.function_key = context.function_key
self.persist = context.persist
self._ttl_seconds = context.ttl_seconds
self._max_entries = context.max_entries
@property
def ttl_seconds(self) -> float:
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
@property
def max_entries(self) -> float:
return float(self._max_entries) if self._max_entries is not None else math.inf
def get(self, key: str) -> bytes:
"""
Returns the stored value for the key if persisted,
raise CacheStorageKeyNotFoundError if not found, or not configured
with persist="disk".
"""
if self.persist == "disk":
path = self._get_cache_file_path(key)
try:
with streamlit_read(path, binary=True) as input:
value = input.read()
_LOGGER.debug("Disk cache HIT: %s", key)
return bytes(value)
except FileNotFoundError:
raise CacheStorageKeyNotFoundError("Key not found in disk cache")
except Exception as ex:
_LOGGER.exception("Error reading from cache")
raise CacheStorageError("Unable to read from cache") from ex
else:
raise CacheStorageKeyNotFoundError(
f"Local disk cache storage is disabled (persist={self.persist})"
)
def set(self, key: str, value: bytes) -> None:
"""Sets the value for a given key."""
if self.persist == "disk":
path = self._get_cache_file_path(key)
try:
with streamlit_write(path, binary=True) as output:
output.write(value)
except errors.Error as ex:
_LOGGER.debug("Unable to write to cache", exc_info=ex)
# Clean up file so we don't leave zero byte files.
try:
os.remove(path)
except (FileNotFoundError, OSError):
# If we can't remove the file, it's not a big deal.
pass
raise CacheStorageError("Unable to write to cache") from ex
def delete(self, key: str) -> None:
"""Delete a cache file from disk. If the file does not exist on disk,
return silently. If another exception occurs, log it. Does not throw.
"""
if self.persist == "disk":
path = self._get_cache_file_path(key)
try:
os.remove(path)
except FileNotFoundError:
# The file is already removed.
pass
except Exception as ex:
_LOGGER.exception(
"Unable to remove a file from the disk cache", exc_info=ex
)
def clear(self) -> None:
"""Delete all keys for the current storage."""
cache_dir = get_cache_folder_path()
if os.path.isdir(cache_dir):
# We try to remove all files in the cache directory that start with
# the function key, whether `clear` called for `self.persist`
# storage or not, to avoid leaving orphaned files in the cache directory.
for file_name in os.listdir(cache_dir):
if self._is_cache_file(file_name):
os.remove(os.path.join(cache_dir, file_name))
def close(self) -> None:
"""Dummy implementation of close, we don't need to actually "close" anything."""
def _get_cache_file_path(self, value_key: str) -> str:
"""Return the path of the disk cache file for the given value."""
cache_dir = get_cache_folder_path()
return os.path.join(
cache_dir, f"{self.function_key}-{value_key}.{_CACHED_FILE_EXTENSION}"
)
def _is_cache_file(self, fname: str) -> bool:
"""Return true if the given file name is a cache file for this storage."""
return fname.startswith(f"{self.function_key}-") and fname.endswith(
f".{_CACHED_FILE_EXTENSION}"
)
def get_cache_folder_path() -> str:
return get_streamlit_file_path(_CACHE_DIR_NAME)

View File

@@ -0,0 +1,435 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import re
from typing import TYPE_CHECKING, Any, Final, Literal, TypeVar, overload
from streamlit.connections import (
BaseConnection,
SnowflakeConnection,
SnowparkConnection,
SQLConnection,
)
from streamlit.deprecation_util import deprecate_obj_name
from streamlit.errors import StreamlitAPIException
from streamlit.runtime.caching import cache_resource
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.secrets import secrets_singleton
if TYPE_CHECKING:
from datetime import timedelta
# NOTE: Adding support for a new first party connection requires:
# 1. Adding the new connection name and class to this dict.
# 2. Writing two new @overloads for connection_factory (one for the case where the
# only the connection name is specified and another when both name and type are).
# 3. Updating test_get_first_party_connection_helper in connection_factory_test.py.
FIRST_PARTY_CONNECTIONS = {
"snowflake": SnowflakeConnection,
"snowpark": SnowparkConnection,
"sql": SQLConnection,
}
MODULE_EXTRACTION_REGEX = re.compile(r"No module named \'(.+)\'")
MODULES_TO_PYPI_PACKAGES: Final[dict[str, str]] = {
"MySQLdb": "mysqlclient",
"psycopg2": "psycopg2-binary",
"sqlalchemy": "sqlalchemy",
"snowflake": "snowflake-connector-python",
"snowflake.connector": "snowflake-connector-python",
"snowflake.snowpark": "snowflake-snowpark-python",
}
# The BaseConnection bound is parameterized to `Any` below as subclasses of
# BaseConnection are responsible for binding the type parameter of BaseConnection to a
# concrete type, but the type it gets bound to isn't important to us here.
ConnectionClass = TypeVar("ConnectionClass", bound=BaseConnection[Any])
@gather_metrics("connection")
def _create_connection(
name: str,
connection_class: type[ConnectionClass],
max_entries: int | None = None,
ttl: float | timedelta | None = None,
**kwargs,
) -> ConnectionClass:
"""Create an instance of connection_class with the given name and kwargs.
The weird implementation of this function with the @cache_resource annotated
function defined internally is done to:
- Always @gather_metrics on the call even if the return value is a cached one.
- Allow the user to specify ttl and max_entries when calling st.connection.
"""
def __create_connection(
name: str, connection_class: type[ConnectionClass], **kwargs
) -> ConnectionClass:
return connection_class(connection_name=name, **kwargs)
if not issubclass(connection_class, BaseConnection):
raise StreamlitAPIException(
f"{connection_class} is not a subclass of BaseConnection!"
)
# We modify our helper function's `__qualname__` here to work around default
# `@st.cache_resource` behavior. Otherwise, `st.connection` being called with
# different `ttl` or `max_entries` values will reset the cache with each call.
ttl_str = str(ttl).replace( # Avoid adding extra `.` characters to `__qualname__`
".", "_"
)
__create_connection.__qualname__ = (
f"{__create_connection.__qualname__}_{ttl_str}_{max_entries}"
)
__create_connection = cache_resource(
max_entries=max_entries,
show_spinner="Running `st.connection(...)`.",
ttl=ttl,
)(__create_connection)
return __create_connection(name, connection_class, **kwargs)
def _get_first_party_connection(connection_class: str):
if connection_class in FIRST_PARTY_CONNECTIONS:
return FIRST_PARTY_CONNECTIONS[connection_class]
raise StreamlitAPIException(
f"Invalid connection '{connection_class}'. "
f"Supported connection classes: {FIRST_PARTY_CONNECTIONS}"
)
@overload
def connection_factory(
name: Literal["sql"],
max_entries: int | None = None,
ttl: float | timedelta | None = None,
autocommit: bool = False,
**kwargs,
) -> SQLConnection:
pass
@overload
def connection_factory(
name: str,
type: Literal["sql"],
max_entries: int | None = None,
ttl: float | timedelta | None = None,
autocommit: bool = False,
**kwargs,
) -> SQLConnection:
pass
@overload
def connection_factory(
name: Literal["snowflake"],
max_entries: int | None = None,
ttl: float | timedelta | None = None,
autocommit: bool = False,
**kwargs,
) -> SnowflakeConnection:
pass
@overload
def connection_factory(
name: str,
type: Literal["snowflake"],
max_entries: int | None = None,
ttl: float | timedelta | None = None,
autocommit: bool = False,
**kwargs,
) -> SnowflakeConnection:
pass
@overload
def connection_factory(
name: Literal["snowpark"],
max_entries: int | None = None,
ttl: float | timedelta | None = None,
**kwargs,
) -> SnowparkConnection:
pass
@overload
def connection_factory(
name: str,
type: Literal["snowpark"],
max_entries: int | None = None,
ttl: float | timedelta | None = None,
**kwargs,
) -> SnowparkConnection:
pass
@overload
def connection_factory(
name: str,
type: type[ConnectionClass],
max_entries: int | None = None,
ttl: float | timedelta | None = None,
**kwargs,
) -> ConnectionClass:
pass
@overload
def connection_factory(
name: str,
type: str | None = None,
max_entries: int | None = None,
ttl: float | timedelta | None = None,
**kwargs,
) -> BaseConnection[Any]:
pass
def connection_factory(
name,
type=None,
max_entries=None,
ttl=None,
**kwargs,
):
"""Create a new connection to a data store or API, or return an existing one.
Configuration options, credentials, and secrets for connections are
combined from the following sources:
- The keyword arguments passed to this command.
- The app's ``secrets.toml`` files.
- Any connection-specific configuration files.
The connection returned from ``st.connection`` is internally cached with
``st.cache_resource`` and is therefore shared between sessions.
Parameters
----------
name : str
The connection name used for secrets lookup in ``secrets.toml``.
Streamlit uses secrets under ``[connections.<name>]`` for the
connection. ``type`` will be inferred if ``name`` is one of the
following: ``"snowflake"``, ``"snowpark"``, or ``"sql"``.
type : str, connection class, or None
The type of connection to create. This can be one of the following:
- ``None`` (default): Streamlit will infer the connection type from
``name``. If the type is not inferrable from ``name``, the type must
be specified in ``secrets.toml`` instead.
- ``"snowflake"``: Streamlit will initialize a connection with
|SnowflakeConnection|_.
- ``"snowpark"``: Streamlit will initialize a connection with
|SnowparkConnection|_. This is deprecated.
- ``"sql"``: Streamlit will initialize a connection with
|SQLConnection|_.
- A string path to an importable class: This must be a dot-separated
module path ending in the importable class. Streamlit will import the
class and initialize a connection with it. The class must extend
``st.connections.BaseConnection``.
- An imported class reference: Streamlit will initialize a connection
with the referenced class, which must extend
``st.connections.BaseConnection``.
.. |SnowflakeConnection| replace:: ``SnowflakeConnection``
.. _SnowflakeConnection: https://docs.streamlit.io/develop/api-reference/connections/st.connections.snowflakeconnection
.. |SnowparkConnection| replace:: ``SnowparkConnection``
.. _SnowparkConnection: https://docs.streamlit.io/develop/api-reference/connections/st.connections.snowparkconnection
.. |SQLConnection| replace:: ``SQLConnection``
.. _SQLConnection: https://docs.streamlit.io/develop/api-reference/connections/st.connections.sqlconnection
max_entries : int or None
The maximum number of connections to keep in the cache.
If this is ``None`` (default), the cache is unbounded. Otherwise, when
a new entry is added to a full cache, the oldest cached entry is
removed.
ttl : float, timedelta, or None
The maximum number of seconds to keep results in the cache.
If this is ``None`` (default), cached results do not expire with time.
**kwargs : any
Connection-specific keyword arguments that are passed to the
connection's ``._connect()`` method. ``**kwargs`` are typically
combined with (and take precendence over) key-value pairs in
``secrets.toml``. To learn more, see the specific connection's
documentation.
Returns
-------
Subclass of BaseConnection
An initialized connection object of the specified ``type``.
Examples
--------
**Example 1: Inferred connection type**
The easiest way to create a first-party (SQL, Snowflake, or Snowpark) connection is
to use their default names and define corresponding sections in your ``secrets.toml``
file. The following example creates a ``"sql"``-type connection.
``.streamlit/secrets.toml``:
>>> [connections.sql]
>>> dialect = "xxx"
>>> host = "xxx"
>>> username = "xxx"
>>> password = "xxx"
Your app code:
>>> import streamlit as st
>>> conn = st.connection("sql")
**Example 2: Named connections**
Creating a connection with a custom name requires you to explicitly
specify the type. If ``type`` is not passed as a keyword argument, it must
be set in the appropriate section of ``secrets.toml``. The following
example creates two ``"sql"``-type connections, each with their own
custom name. The first defines ``type`` in the ``st.connection`` command;
the second defines ``type`` in ``secrets.toml``.
``.streamlit/secrets.toml``:
>>> [connections.first_connection]
>>> dialect = "xxx"
>>> host = "xxx"
>>> username = "xxx"
>>> password = "xxx"
>>>
>>> [connections.second_connection]
>>> type = "sql"
>>> dialect = "yyy"
>>> host = "yyy"
>>> username = "yyy"
>>> password = "yyy"
Your app code:
>>> import streamlit as st
>>> conn1 = st.connection("first_connection", type="sql")
>>> conn2 = st.connection("second_connection")
**Example 3: Using a path to the connection class**
Passing the full module path to the connection class can be useful,
especially when working with a custom connection. Although this is not the
typical way to create first party connections, the following example
creates the same type of connection as one with ``type="sql"``. Note that
``type`` is a string path.
``.streamlit/secrets.toml``:
>>> [connections.my_sql_connection]
>>> url = "xxx+xxx://xxx:xxx@xxx:xxx/xxx"
Your app code:
>>> import streamlit as st
>>> conn = st.connection(
... "my_sql_connection", type="streamlit.connections.SQLConnection"
... )
**Example 4: Importing the connection class**
You can pass the connection class directly to the ``st.connection``
command. Doing so allows static type checking tools such as ``mypy`` to
infer the exact return type of ``st.connection``. The following example
creates the same connection as in Example 3.
``.streamlit/secrets.toml``:
>>> [connections.my_sql_connection]
>>> url = "xxx+xxx://xxx:xxx@xxx:xxx/xxx"
Your app code:
>>> import streamlit as st
>>> from streamlit.connections import SQLConnection
>>> conn = st.connection("my_sql_connection", type=SQLConnection)
"""
USE_ENV_PREFIX = "env:"
if name.startswith(USE_ENV_PREFIX):
# It'd be nice to use str.removeprefix() here, but we won't be able to do that
# until the minimium Python version we support is 3.9.
envvar_name = name[len(USE_ENV_PREFIX) :]
name = os.environ[envvar_name]
if type is None:
if name in FIRST_PARTY_CONNECTIONS:
# We allow users to simply write `st.connection("sql")` instead of
# `st.connection("sql", type="sql")`.
type = _get_first_party_connection(name)
else:
# The user didn't specify a type, so we try to pull it out from their
# secrets.toml file. NOTE: we're okay with any of the dict lookups below
# exploding with a KeyError since, if type isn't explicitly specified here,
# it must be the case that it's defined in secrets.toml and should raise an
# Exception otherwise.
secrets_singleton.load_if_toml_exists()
type = secrets_singleton["connections"][name]["type"]
# type is a nice kwarg name for the st.connection user but is annoying to work with
# since it conflicts with the builtin function name and thus gets syntax
# highlighted.
connection_class = type
if isinstance(connection_class, str):
# We assume that a connection_class specified via string is either the fully
# qualified name of a class (its module and exported classname) or the string
# literal shorthand for one of our first party connections. In the former case,
# connection_class will always contain a "." in its name.
if "." in connection_class:
parts = connection_class.split(".")
classname = parts.pop()
import importlib
connection_module = importlib.import_module(".".join(parts))
connection_class = getattr(connection_module, classname)
else:
connection_class = _get_first_party_connection(connection_class)
# At this point, connection_class should be of type Type[ConnectionClass].
try:
conn = _create_connection(
name, connection_class, max_entries=max_entries, ttl=ttl, **kwargs
)
if isinstance(conn, SnowparkConnection):
conn = deprecate_obj_name(
conn,
'connection("snowpark")',
'connection("snowflake")',
"2024-04-01",
)
return conn
except ModuleNotFoundError as e:
err_string = str(e)
missing_module = re.search(MODULE_EXTRACTION_REGEX, err_string)
extra_info = "You may be missing a dependency required to use this connection."
if missing_module:
pypi_package = MODULES_TO_PYPI_PACKAGES.get(missing_module.group(1))
if pypi_package:
extra_info = f"You need to install the '{pypi_package}' package to use this connection."
raise ModuleNotFoundError(f"{str(e)}. {extra_info}")

View File

@@ -0,0 +1,300 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterable, Iterator, Mapping
from functools import lru_cache
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, cast
from streamlit import runtime
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
if TYPE_CHECKING:
from http.cookies import Morsel
from tornado.httputil import HTTPHeaders, HTTPServerRequest
from tornado.web import RequestHandler
def _get_request() -> HTTPServerRequest | None:
ctx = get_script_run_ctx()
if ctx is None:
return None
session_client = runtime.get_instance().get_client(ctx.session_id)
if session_client is None:
return None
# We return websocket request only if session_client is an instance of
# BrowserWebSocketHandler (which is True for the Streamlit open-source
# implementation). For any other implementation, we return None.
# We are not using `type_util.is_type` here to avoid circular import.
if (
f"{type(session_client).__module__}.{type(session_client).__qualname__}"
!= "streamlit.web.server.browser_websocket_handler.BrowserWebSocketHandler"
):
return None
return cast("RequestHandler", session_client).request
@lru_cache
def _normalize_header(name: str) -> str:
"""Map a header name to Http-Header-Case.
>>> _normalize_header("coNtent-TYPE")
'Content-Type'
"""
return "-".join(w.capitalize() for w in name.split("-"))
class StreamlitHeaders(Mapping[str, str]):
def __init__(self, headers: Iterable[tuple[str, str]]):
dict_like_headers: dict[str, list[str]] = {}
for key, value in headers:
header_value = dict_like_headers.setdefault(_normalize_header(key), [])
header_value.append(value)
self._headers = dict_like_headers
@classmethod
def from_tornado_headers(cls, tornado_headers: HTTPHeaders) -> StreamlitHeaders:
return cls(tornado_headers.get_all())
def get_all(self, key: str) -> list[str]:
return list(self._headers.get(_normalize_header(key), []))
def __getitem__(self, key: str) -> str:
try:
return self._headers[_normalize_header(key)][0]
except LookupError:
raise KeyError(key) from None
def __len__(self) -> int:
"""Number of unique headers present in request."""
return len(self._headers)
def __iter__(self) -> Iterator[str]:
return iter(self._headers)
def to_dict(self) -> dict[str, str]:
return {key: self[key] for key in self}
class StreamlitCookies(Mapping[str, str]):
def __init__(self, cookies: Mapping[str, str]):
self._cookies = MappingProxyType(cookies)
@classmethod
def from_tornado_cookies(
cls, tornado_cookies: dict[str, Morsel[Any]]
) -> StreamlitCookies:
dict_like_cookies = {}
for key, morsel in tornado_cookies.items():
dict_like_cookies[key] = morsel.value
return cls(dict_like_cookies)
def __getitem__(self, key: str) -> str:
return self._cookies[key]
def __len__(self) -> int:
"""Number of unique headers present in request."""
return len(self._cookies)
def __iter__(self) -> Iterator[str]:
return iter(self._cookies)
def to_dict(self) -> dict[str, str]:
return dict(self._cookies)
class ContextProxy:
"""An interface to access user session context.
``st.context`` provides a read-only interface to access headers and cookies
for the current user session.
Each property (``st.context.headers`` and ``st.context.cookies``) returns
a dictionary of named values.
"""
@property
@gather_metrics("context.headers")
def headers(self) -> StreamlitHeaders:
"""A read-only, dict-like object containing headers sent in the initial request.
Keys are case-insensitive and may be repeated. When keys are repeated,
dict-like methods will only return the last instance of each key. Use
``.get_all(key="your_repeated_key")`` to see all values if the same
header is set multiple times.
Examples
--------
**Example 1: Access all available headers**
Show a dictionary of headers (with only the last instance of any
repeated key):
>>> import streamlit as st
>>>
>>> st.context.headers
**Example 2: Access a specific header**
Show the value of a specific header (or the last instance if it's
repeated):
>>> import streamlit as st
>>>
>>> st.context.headers["host"]
Show of list of all headers for a given key:
>>> import streamlit as st
>>>
>>> st.context.headers.get_all("pragma")
"""
# We have a docstring in line above as one-liner, to have a correct docstring
# in the st.write(st,context) call.
session_client_request = _get_request()
if session_client_request is None:
return StreamlitHeaders({})
return StreamlitHeaders.from_tornado_headers(session_client_request.headers)
@property
@gather_metrics("context.cookies")
def cookies(self) -> StreamlitCookies:
"""A read-only, dict-like object containing cookies sent in the initial request.
Examples
--------
**Example 1: Access all available cookies**
Show a dictionary of cookies:
>>> import streamlit as st
>>>
>>> st.context.cookies
**Example 2: Access a specific cookie**
Show the value of a specific cookie:
>>> import streamlit as st
>>>
>>> st.context.cookies["_ga"]
"""
# We have a docstring in line above as one-liner, to have a correct docstring
# in the st.write(st,context) call.
session_client_request = _get_request()
if session_client_request is None:
return StreamlitCookies({})
cookies = session_client_request.cookies
return StreamlitCookies.from_tornado_cookies(cookies)
@property
@gather_metrics("context.timezone")
def timezone(self) -> str | None:
"""The read-only timezone of the user's browser.
Example
-------
Access the user's timezone, and format a datetime to display locally:
>>> import streamlit as st
>>> from datetime import datetime, timezone
>>> import pytz
>>>
>>> tz = st.context.timezone
>>> tz_obj = pytz.timezone(tz)
>>>
>>> now = datetime.now(timezone.utc)
>>>
>>> f"The user's timezone is {tz}."
>>> f"The UTC time is {now}."
>>> f"The user's local time is {now.astimezone(tz_obj)}"
"""
ctx = get_script_run_ctx()
if ctx is None or ctx.context_info is None:
return None
return ctx.context_info.timezone
@property
@gather_metrics("context.timezone_offset")
def timezone_offset(self) -> int | None:
"""The read-only timezone offset of the user's browser.
Example
-------
Access the user's timezone offset, and format a datetime to display locally:
>>> import streamlit as st
>>> from datetime import datetime, timezone, timedelta
>>>
>>> tzoff = st.context.timezone_offset
>>> tz_obj = timezone(-timedelta(minutes=tzoff))
>>>
>>> now = datetime.now(timezone.utc)
>>>
>>> f"The user's timezone is {tz}."
>>> f"The UTC time is {now}."
>>> f"The user's local time is {now.astimezone(tz_obj)}"
"""
ctx = get_script_run_ctx()
if ctx is None or ctx.context_info is None:
return None
return ctx.context_info.timezone_offset
@property
@gather_metrics("context.locale")
def locale(self) -> str | None:
"""The read-only locale of the user's browser.
``st.context.locale`` returns the value of |navigator.language|_ from
the user's DOM. This is a string representing the user's preferred
language (e.g. "en-US").
.. |navigator.language| replace:: ``navigator.language``
.. _navigator.language: https://developer.mozilla.org/en-US/docs/Web/API/Navigator/language
Example
-------
Access the user's locale to display locally:
>>> import streamlit as st
>>>
>>> if st.context.locale == "fr-FR":
>>> st.write("Bonjour!")
>>> else:
>>> st.write("Hello!")
"""
ctx = get_script_run_ctx()
if ctx is None or ctx.context_info is None:
return None
return ctx.context_info.locale

View File

@@ -0,0 +1,364 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Manage the user's Streamlit credentials."""
from __future__ import annotations
import json
import os
import sys
import textwrap
from typing import Final, NamedTuple, NoReturn
from uuid import uuid4
from streamlit import cli_util, env_util, file_util, util
from streamlit.logger import get_logger
_LOGGER: Final = get_logger(__name__)
if env_util.IS_WINDOWS:
_CONFIG_FILE_PATH = r"%userprofile%/.streamlit/config.toml"
else:
_CONFIG_FILE_PATH = "~/.streamlit/config.toml"
class _Activation(NamedTuple):
email: str | None # the user's email.
is_valid: bool # whether the email is valid.
def email_prompt() -> str:
# Emoji can cause encoding errors on non-UTF-8 terminals
# (See https://github.com/streamlit/streamlit/issues/2284.)
# WT_SESSION is a Windows Terminal specific environment variable. If it exists,
# we are on the latest Windows Terminal that supports emojis
show_emoji = sys.stdout.encoding == "utf-8" and (
not env_util.IS_WINDOWS or os.environ.get("WT_SESSION")
)
# IMPORTANT: Break the text below at 80 chars.
return f"""
{"👋 " if show_emoji else ""}{cli_util.style_for_cli("Welcome to Streamlit!", bold=True)}
If youd like to receive helpful onboarding emails, news, offers, promotions,
and the occasional swag, please enter your email address below. Otherwise,
leave this field blank.
{cli_util.style_for_cli("Email: ", fg="blue")}"""
_TELEMETRY_HEADLESS_TEXT = """
Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
"""
def _send_email(email: str) -> None:
"""Send the user's email for metrics, if submitted."""
import requests
if email is None or "@" not in email:
return
metrics_url = ""
try:
response_json = requests.get(
"https://data.streamlit.io/metrics.json", timeout=2
).json()
metrics_url = response_json.get("url", "")
except Exception:
_LOGGER.error("Failed to fetch metrics URL")
return
headers = {
"accept": "*/*",
"accept-language": "en-US,en;q=0.9",
"content-type": "application/json",
"origin": "localhost:8501",
"referer": "localhost:8501/",
}
data = {
"anonymous_id": None,
"messageId": str(uuid4()),
"event": "submittedEmail",
"author_email": email,
"source": "provided_email",
"type": "track",
"userId": email,
}
response = requests.post(
metrics_url,
headers=headers,
data=json.dumps(data).encode(),
)
response.raise_for_status()
class Credentials:
"""Credentials class."""
_singleton: Credentials | None = None
@classmethod
def get_current(cls):
"""Return the singleton instance."""
if cls._singleton is None:
Credentials()
return Credentials._singleton
def __init__(self):
"""Initialize class."""
if Credentials._singleton is not None:
raise RuntimeError(
"Credentials already initialized. Use .get_current() instead"
)
self.activation = None
self._conf_file = _get_credential_file_path()
Credentials._singleton = self
def __repr__(self) -> str:
return util.repr_(self)
def load(self, auto_resolve: bool = False) -> None:
"""Load from toml file."""
if self.activation is not None:
_LOGGER.error("Credentials already loaded. Not rereading file.")
return
import toml
try:
with open(self._conf_file) as f:
data = toml.load(f).get("general")
if data is None:
raise Exception
self.activation = _verify_email(data.get("email"))
except FileNotFoundError:
if auto_resolve:
self.activate(show_instructions=not auto_resolve)
return
raise RuntimeError(
'Credentials not found. Please run "streamlit activate".'
)
except Exception:
if auto_resolve:
self.reset()
self.activate(show_instructions=not auto_resolve)
return
raise Exception(
textwrap.dedent(
"""
Unable to load credentials from %s.
Run "streamlit reset" and try again.
"""
)
% (self._conf_file)
)
def _check_activated(self, auto_resolve: bool = True) -> None:
"""Check if streamlit is activated.
Used by `streamlit run script.py`
"""
try:
self.load(auto_resolve)
except (Exception, RuntimeError) as e:
_exit(str(e))
if self.activation is None or not self.activation.is_valid:
_exit("Activation email not valid.")
@classmethod
def reset(cls) -> None:
"""Reset credentials by removing file.
This is used by `streamlit activate reset` in case a user wants
to start over.
"""
c = Credentials.get_current()
c.activation = None
try:
os.remove(c._conf_file)
except OSError:
_LOGGER.exception("Error removing credentials file.")
def save(self) -> None:
"""Save to toml file and send email."""
from requests.exceptions import RequestException
if self.activation is None:
return
# Create intermediate directories if necessary
os.makedirs(os.path.dirname(self._conf_file), exist_ok=True)
# Write the file
data = {"email": self.activation.email}
import toml
with open(self._conf_file, "w") as f:
toml.dump({"general": data}, f)
try:
_send_email(self.activation.email)
except RequestException:
_LOGGER.exception("Error saving email:")
def activate(self, show_instructions: bool = True) -> None:
"""Activate Streamlit.
Used by `streamlit activate`.
"""
try:
self.load()
except RuntimeError:
# Runtime Error is raised if credentials file is not found. In that case,
# `self.activation` is None and we will show the activation prompt below.
pass
if self.activation:
if self.activation.is_valid:
_exit("Already activated")
else:
_exit(
"Activation not valid. Please run "
"`streamlit activate reset` then `streamlit activate`"
)
else:
activated = False
while not activated:
import click
email = click.prompt(
text=email_prompt(),
prompt_suffix="",
default="",
show_default=False,
)
self.activation = _verify_email(email)
if self.activation.is_valid:
self.save()
# IMPORTANT: Break the text below at 80 chars.
TELEMETRY_TEXT = """
You can find our privacy policy at %(link)s
Summary:
- This open source library collects usage statistics.
- We cannot see and do not store information contained inside Streamlit apps,
such as text, charts, images, etc.
- Telemetry data is stored in servers in the United States.
- If you'd like to opt out, add the following to %(config)s,
creating that file if necessary:
[browser]
gatherUsageStats = false
""" % {
"link": cli_util.style_for_cli(
"https://streamlit.io/privacy-policy", underline=True
),
"config": cli_util.style_for_cli(_CONFIG_FILE_PATH),
}
cli_util.print_to_cli(TELEMETRY_TEXT)
if show_instructions:
# IMPORTANT: Break the text below at 80 chars.
INSTRUCTIONS_TEXT = """
%(start)s
%(prompt)s %(hello)s
""" % {
"start": cli_util.style_for_cli(
"Get started by typing:", fg="blue", bold=True
),
"prompt": cli_util.style_for_cli("$", fg="blue"),
"hello": cli_util.style_for_cli(
"streamlit hello", bold=True
),
}
cli_util.print_to_cli(INSTRUCTIONS_TEXT)
activated = True
else: # pragma: nocover
_LOGGER.error("Please try again.")
def _verify_email(email: str) -> _Activation:
"""Verify the user's email address.
The email can either be an empty string (if the user chooses not to enter
it), or a string with a single '@' somewhere in it.
Parameters
----------
email : str
Returns
-------
_Activation
An _Activation object. Its 'is_valid' property will be True only if
the email was validated.
"""
email = email.strip()
# We deliberately use simple email validation here
# since we do not use email address anywhere to send emails.
if len(email) > 0 and email.count("@") != 1:
_LOGGER.error("That doesn't look like an email :(")
return _Activation(None, False)
return _Activation(email, True)
def _exit(message: str) -> NoReturn:
"""Exit program with error."""
_LOGGER.error(message)
sys.exit(-1)
def _get_credential_file_path() -> str:
return file_util.get_streamlit_file_path("credentials.toml")
def _check_credential_file_exists() -> bool:
return os.path.exists(_get_credential_file_path())
def check_credentials() -> None:
"""Check credentials and potentially activate.
Note
----
If there is no credential file and we are in headless mode, we should not
check, since credential would be automatically set to an empty string.
"""
from streamlit import config
if not _check_credential_file_exists() and config.get_option("server.headless"):
if not config.is_manually_set("browser.gatherUsageStats"):
# If not manually defined, show short message about usage stats gathering.
cli_util.print_to_cli(_TELEMETRY_HEADLESS_TEXT)
return
Credentials.get_current()._check_activated()

View File

@@ -0,0 +1,296 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Final
from weakref import WeakKeyDictionary
from streamlit import config, util
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
if TYPE_CHECKING:
from collections.abc import MutableMapping
from streamlit.runtime.app_session import AppSession
_LOGGER: Final = get_logger(__name__)
def populate_hash_if_needed(msg: ForwardMsg) -> str:
"""Computes and assigns the unique hash for a ForwardMsg.
If the ForwardMsg already has a hash, this is a no-op.
Parameters
----------
msg : ForwardMsg
Returns
-------
string
The message's hash, returned here for convenience. (The hash
will also be assigned to the ForwardMsg; callers do not need
to do this.)
"""
if msg.hash == "":
# Move the message's metadata aside. It's not part of the
# hash calculation.
metadata = msg.metadata
msg.ClearField("metadata")
# MD5 is good enough for what we need, which is uniqueness.
msg.hash = util.calc_md5(msg.SerializeToString())
# Restore metadata.
msg.metadata.CopyFrom(metadata)
return msg.hash
def create_reference_msg(msg: ForwardMsg) -> ForwardMsg:
"""Create a ForwardMsg that refers to the given message via its hash.
The reference message will also get a copy of the source message's
metadata.
Parameters
----------
msg : ForwardMsg
The ForwardMsg to create the reference to.
Returns
-------
ForwardMsg
A new ForwardMsg that "points" to the original message via the
ref_hash field.
"""
ref_msg = ForwardMsg()
ref_msg.ref_hash = populate_hash_if_needed(msg)
ref_msg.metadata.CopyFrom(msg.metadata)
return ref_msg
class ForwardMsgCache(CacheStatsProvider):
"""A cache of ForwardMsgs.
Large ForwardMsgs (e.g. those containing big DataFrame payloads) are
stored in this cache. The server can choose to send a ForwardMsg's hash,
rather than the message itself, to a client. Clients can then
request messages from this cache via another endpoint.
This cache is *not* thread safe. It's intended to only be accessed by
the server thread.
"""
class Entry:
"""Cache entry.
Stores the cached message, and the set of AppSessions
that we've sent the cached message to.
"""
def __init__(self, msg: ForwardMsg | None):
self.msg = msg
self._session_script_run_counts: MutableMapping[AppSession, int] = (
WeakKeyDictionary()
)
def __repr__(self) -> str:
return util.repr_(self)
def add_session_ref(self, session: AppSession, script_run_count: int) -> None:
"""Adds a reference to a AppSession that has referenced
this Entry's message.
Parameters
----------
session : AppSession
script_run_count : int
The session's run count at the time of the call
"""
prev_run_count = self._session_script_run_counts.get(session, 0)
if script_run_count < prev_run_count:
_LOGGER.error(
"New script_run_count (%s) is < prev_run_count (%s). "
"This should never happen!",
script_run_count,
prev_run_count,
)
script_run_count = prev_run_count
self._session_script_run_counts[session] = script_run_count
def has_session_ref(self, session: AppSession) -> bool:
return session in self._session_script_run_counts
def get_session_ref_age(
self, session: AppSession, script_run_count: int
) -> int:
"""The age of the given session's reference to the Entry,
given a new script_run_count.
"""
return script_run_count - self._session_script_run_counts[session]
def remove_session_ref(self, session: AppSession) -> None:
del self._session_script_run_counts[session]
def has_refs(self) -> bool:
"""True if this Entry has references from any AppSession.
If not, it can be removed from the cache.
"""
return len(self._session_script_run_counts) > 0
def __init__(self):
self._entries: dict[str, ForwardMsgCache.Entry] = {}
def __repr__(self) -> str:
return util.repr_(self)
def add_message(
self, msg: ForwardMsg, session: AppSession, script_run_count: int
) -> None:
"""Add a ForwardMsg to the cache.
The cache will also record a reference to the given AppSession,
so that it can track which sessions have already received
each given ForwardMsg.
Parameters
----------
msg : ForwardMsg
session : AppSession
script_run_count : int
The number of times the session's script has run
"""
populate_hash_if_needed(msg)
entry = self._entries.get(msg.hash, None)
if entry is None:
if config.get_option("global.storeCachedForwardMessagesInMemory"):
entry = ForwardMsgCache.Entry(msg)
else:
entry = ForwardMsgCache.Entry(None)
self._entries[msg.hash] = entry
entry.add_session_ref(session, script_run_count)
def get_message(self, hash: str) -> ForwardMsg | None:
"""Return the message with the given ID if it exists in the cache.
Parameters
----------
hash : str
The id of the message to retrieve.
Returns
-------
ForwardMsg | None
"""
entry = self._entries.get(hash, None)
return entry.msg if entry else None
def has_message_reference(
self, msg: ForwardMsg, session: AppSession, script_run_count: int
) -> bool:
"""Return True if a session has a reference to a message."""
populate_hash_if_needed(msg)
entry = self._entries.get(msg.hash, None)
if entry is None or not entry.has_session_ref(session):
return False
# Ensure we're not expired
age = entry.get_session_ref_age(session, script_run_count)
return age <= int(config.get_option("global.maxCachedMessageAge"))
def remove_refs_for_session(self, session: AppSession) -> None:
"""Remove refs for all entries for the given session.
This should be called when an AppSession is disconnected or closed.
Parameters
----------
session : AppSession
"""
# Operate on a copy of our entries dict.
# We may be deleting from it.
for msg_hash, entry in self._entries.copy().items():
if entry.has_session_ref(session):
entry.remove_session_ref(session)
if not entry.has_refs():
# The entry has no more references. Remove it from
# the cache completely.
del self._entries[msg_hash]
def remove_expired_entries_for_session(
self, session: AppSession, script_run_count: int
) -> None:
"""Remove any cached messages that have expired from the given session.
This should be called each time a AppSession finishes executing.
Parameters
----------
session : AppSession
script_run_count : int
The number of times the session's script has run
"""
max_age = config.get_option("global.maxCachedMessageAge")
# Operate on a copy of our entries dict.
# We may be deleting from it.
for msg_hash, entry in self._entries.copy().items():
if not entry.has_session_ref(session):
continue
age = entry.get_session_ref_age(session, script_run_count)
if age > max_age:
_LOGGER.debug(
"Removing expired entry [session=%s, hash=%s, age=%s]",
id(session),
msg_hash,
age,
)
entry.remove_session_ref(session)
if not entry.has_refs():
# The entry has no more references. Remove it from
# the cache completely.
del self._entries[msg_hash]
def clear(self) -> None:
"""Remove all entries from the cache."""
self._entries.clear()
def get_stats(self) -> list[CacheStat]:
stats: list[CacheStat] = [
CacheStat(
category_name="ForwardMessageCache",
cache_name="",
byte_length=entry.msg.ByteSize() if entry.msg is not None else 0,
)
for _, entry in self._entries.items()
]
return group_stats(stats)

View File

@@ -0,0 +1,241 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
if TYPE_CHECKING:
from streamlit.proto.Delta_pb2 import Delta
class ForwardMsgQueue:
"""Accumulates a session's outgoing ForwardMsgs.
Each AppSession adds messages to its queue, and the Server periodically
flushes all session queues and delivers their messages to the appropriate
clients.
ForwardMsgQueue is not thread-safe - a queue should only be used from
a single thread.
"""
_before_enqueue_msg: Callable[[ForwardMsg], None] | None = None
@staticmethod
def on_before_enqueue_msg(
before_enqueue_msg: Callable[[ForwardMsg], None] | None,
) -> None:
"""Set a callback to be called before a message is enqueued.
Used in static streamlit app generation.
"""
ForwardMsgQueue._before_enqueue_msg = before_enqueue_msg
def __init__(self):
self._queue: list[ForwardMsg] = []
# A mapping of (delta_path -> _queue.indexof(msg)) for each
# Delta message in the queue. We use this for coalescing
# redundant outgoing Deltas (where a newer Delta supersedes
# an older Delta, with the same delta_path, that's still in the
# queue).
self._delta_index_map: dict[tuple[int, ...], int] = {}
def get_debug(self) -> dict[str, Any]:
from google.protobuf.json_format import MessageToDict
return {
"queue": [MessageToDict(m) for m in self._queue],
"ids": list(self._delta_index_map.keys()),
}
def is_empty(self) -> bool:
return len(self._queue) == 0
def enqueue(self, msg: ForwardMsg) -> None:
"""Add message into queue, possibly composing it with another message."""
if ForwardMsgQueue._before_enqueue_msg:
ForwardMsgQueue._before_enqueue_msg(msg)
if not _is_composable_message(msg):
self._queue.append(msg)
return
# If there's a Delta message with the same delta_path already in
# the queue - meaning that it refers to the same location in
# the app - we attempt to combine this new Delta into the old
# one. This is an optimization that prevents redundant Deltas
# from being sent to the frontend.
delta_key = tuple(msg.metadata.delta_path)
if delta_key in self._delta_index_map:
index = self._delta_index_map[delta_key]
old_msg = self._queue[index]
composed_delta = _maybe_compose_deltas(old_msg.delta, msg.delta)
if composed_delta is not None:
new_msg = ForwardMsg()
new_msg.delta.CopyFrom(composed_delta)
new_msg.metadata.CopyFrom(msg.metadata)
self._queue[index] = new_msg
return
# No composition occurred. Append this message to the queue, and
# store its index for potential future composition.
self._delta_index_map[delta_key] = len(self._queue)
self._queue.append(msg)
def clear(
self,
retain_lifecycle_msgs: bool = False,
fragment_ids_this_run: list[str] | None = None,
) -> None:
"""Clear the queue, potentially retaining lifecycle messages.
The retain_lifecycle_msgs argument exists because in some cases (in particular
when a currently running script is interrupted by a new BackMsg), we don't want
to remove certain messages from the queue as doing so may cause the client to
not hear about important script lifecycle events (such as the script being
stopped early in order to be rerun).
If fragment_ids_this_run is provided, delta messages not belonging to any
fragment or belonging to a fragment not in fragment_ids_this_run will be
preserved to prevent clearing messages unrelated to the running fragments.
"""
if not retain_lifecycle_msgs:
self._queue = []
else:
self._queue = [
_update_script_finished_message(msg, fragment_ids_this_run is not None)
for msg in self._queue
if msg.WhichOneof("type")
in {
"new_session",
"script_finished",
"session_status_changed",
"parent_message",
}
or (
# preserve all messages if this is a fragment rerun and...
fragment_ids_this_run is not None
and (
# the message is not a delta message
# (not associated with a fragment) or...
msg.delta is None
or (
# it is a delta but not associated with any of the passed
# fragments
msg.delta is not None
and (
msg.delta.fragment_id is None
or msg.delta.fragment_id not in fragment_ids_this_run
)
)
)
)
]
self._delta_index_map = {}
def flush(self) -> list[ForwardMsg]:
"""Clear the queue and return a list of the messages it contained
before being cleared.
"""
queue = self._queue
self.clear()
return queue
def __len__(self) -> int:
return len(self._queue)
def _is_composable_message(msg: ForwardMsg) -> bool:
"""True if the ForwardMsg is potentially composable with other ForwardMsgs."""
if not msg.HasField("delta"):
# Non-delta messages are never composable.
return False
# We never compose add_rows messages in Python, because the add_rows
# operation can raise errors, and we don't have a good way of handling
# those errors in the message queue.
delta_type = msg.delta.WhichOneof("type")
return delta_type != "add_rows" and delta_type != "arrow_add_rows"
def _maybe_compose_deltas(old_delta: Delta, new_delta: Delta) -> Delta | None:
"""Combines new_delta onto old_delta if possible.
If the combination takes place, the function returns a new Delta that
should replace old_delta in the queue.
If the new_delta is incompatible with old_delta, the function returns None.
In this case, the new_delta should just be appended to the queue as normal.
"""
old_delta_type = old_delta.WhichOneof("type")
if old_delta_type == "add_block":
# We never replace add_block deltas, because blocks can have
# other dependent deltas later in the queue. For example:
#
# placeholder = st.empty()
# placeholder.columns(1)
# placeholder.empty()
#
# The call to "placeholder.columns(1)" creates two blocks, a parent
# container with delta_path (0, 0), and a column child with
# delta_path (0, 0, 0). If the final "placeholder.empty()" Delta
# is composed with the parent container Delta, the frontend will
# throw an error when it tries to add that column child to what is
# now just an element, and not a block.
return None
new_delta_type = new_delta.WhichOneof("type")
if new_delta_type == "new_element":
return new_delta
if new_delta_type == "add_block":
return new_delta
return None
def _update_script_finished_message(
msg: ForwardMsg, is_fragment_run: bool
) -> ForwardMsg:
"""
When we are here, the message queue is cleared from non-lifecycle messages
before they were flushed to the browser.
If there were no non-lifecycle messages in the queue, changing the type here
should not matter for the frontend anyways, so we optimistically change the
`script_finished` message to `FINISHED_EARLY_FOR_RERUN`. This indicates to
the frontend that the previous run was interrupted by a new script start.
Otherwise, a `FINISHED_SUCCESSFULLY` message might trigger a reset of widget
states on the frontend.
"""
if msg.WhichOneof("type") == "script_finished" and (
# If this is not a fragment run (= full app run), its okay to change the
# script_finished type to FINISHED_EARLY_FOR_RERUN because another full app run
# is about to start.
# If this is a fragment run, it is allowed to change the state of
# all script_finished states except for FINISHED_SUCCESSFULLY, which we use to
# indicate that a full app run has finished successfully (in other words, a
# fragment should not modify the finished status of a full app run, because
# the fragment finished state is different and the frontend might not trigger
# cleanups etc. correctly).
is_fragment_run is False
or msg.script_finished != ForwardMsg.ScriptFinishedStatus.FINISHED_SUCCESSFULLY
):
msg.script_finished = ForwardMsg.ScriptFinishedStatus.FINISHED_EARLY_FOR_RERUN
return msg

View File

@@ -0,0 +1,478 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import inspect
from abc import abstractmethod
from copy import deepcopy
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar, overload
from streamlit.deprecation_util import (
make_deprecated_name_warning,
show_deprecation_warning,
)
from streamlit.error_util import handle_uncaught_app_exception
from streamlit.errors import FragmentHandledException, FragmentStorageKeyError
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner_utils.exceptions import (
RerunException,
StopException,
)
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
from streamlit.time_util import time_to_seconds
from streamlit.util import calc_md5
if TYPE_CHECKING:
from datetime import timedelta
F = TypeVar("F", bound=Callable[..., Any])
Fragment = Callable[[], Any]
class FragmentStorage(Protocol):
"""A key-value store for Fragments. Used to implement the @st.fragment decorator.
We intentionally define this as its own protocol despite how generic it appears to
be at first glance. The reason why is that, in any case where fragments aren't just
stored as Python closures in memory, storing and retrieving Fragments will generally
involve serializing and deserializing function bytecode, which is a tricky aspect
to implementing FragmentStorages that won't generally appear with our other *Storage
protocols.
"""
# Weirdly, we have to define this above the `set` method, or mypy gets it confused
# with the `set` type of `new_fragments_ids`.
@abstractmethod
def clear(self, new_fragment_ids: set[str] | None = None) -> None:
"""Remove all fragments saved in this FragmentStorage unless listed in
new_fragment_ids.
"""
raise NotImplementedError
@abstractmethod
def get(self, key: str) -> Fragment:
"""Returns the stored fragment for the given key."""
raise NotImplementedError
@abstractmethod
def set(self, key: str, value: Fragment) -> None:
"""Saves a fragment under the given key."""
raise NotImplementedError
@abstractmethod
def delete(self, key: str) -> None:
"""Delete the fragment corresponding to the given key."""
raise NotImplementedError
@abstractmethod
def contains(self, key: str) -> bool:
"""Return whether the given key is present in this FragmentStorage."""
raise NotImplementedError
# NOTE: Ideally, we'd like to add a MemoryFragmentStorageStatProvider implementation to
# keep track of memory usage due to fragments, but doing something like this ends up
# being difficult in practice as the memory usage of a closure is hard to measure (the
# vendored implementation of pympler.asizeof that we use elsewhere is unable to measure
# the size of a function).
class MemoryFragmentStorage(FragmentStorage):
"""A simple, memory-backed implementation of FragmentStorage.
MemoryFragmentStorage is just a wrapper around a plain Python dict that complies with
the FragmentStorage protocol.
"""
def __init__(self) -> None:
self._fragments: dict[str, Fragment] = {}
# Weirdly, we have to define this above the `set` method, or mypy gets it confused
# with the `set` type of `new_fragments_ids`.
def clear(self, new_fragment_ids: set[str] | None = None) -> None:
if new_fragment_ids is None:
new_fragment_ids = set()
fragment_ids = list(self._fragments.keys())
for fid in fragment_ids:
if fid not in new_fragment_ids:
del self._fragments[fid]
def get(self, key: str) -> Fragment:
try:
return self._fragments[key]
except KeyError as e:
raise FragmentStorageKeyError(str(e))
def set(self, key: str, value: Fragment) -> None:
self._fragments[key] = value
def delete(self, key: str) -> None:
try:
del self._fragments[key]
except KeyError as e:
raise FragmentStorageKeyError(str(e))
def contains(self, key: str) -> bool:
return key in self._fragments
def _fragment(
func: F | None = None,
*,
run_every: int | float | timedelta | str | None = None,
additional_hash_info: str = "",
should_show_deprecation_warning: bool = False,
) -> Callable[[F], F] | F:
"""Contains the actual fragment logic.
This function should be used by our internal functions that use fragments
under-the-hood, so that fragment metrics are not tracked for those elements
(note that the @gather_metrics annotation is only on the publicly exposed function)
"""
if func is None:
# Support passing the params via function decorator
def wrapper(f: F) -> F:
return fragment(
func=f,
run_every=run_every,
)
return wrapper
else:
non_optional_func = func
@wraps(non_optional_func)
def wrap(*args, **kwargs):
from streamlit.delta_generator_singletons import context_dg_stack
ctx = get_script_run_ctx()
if ctx is None:
return None
cursors_snapshot = deepcopy(ctx.cursors)
dg_stack_snapshot = deepcopy(context_dg_stack.get())
fragment_id = calc_md5(
f"{non_optional_func.__module__}.{non_optional_func.__qualname__}{dg_stack_snapshot[-1]._get_delta_path_str()}{additional_hash_info}"
)
# We intentionally want to capture the active script hash here to ensure
# that the fragment is associated with the correct script running.
initialized_active_script_hash = ctx.active_script_hash
def wrapped_fragment():
import streamlit as st
if should_show_deprecation_warning:
show_deprecation_warning(
make_deprecated_name_warning(
"experimental_fragment",
"fragment",
"2025-01-01",
)
)
# NOTE: We need to call get_script_run_ctx here again and can't just use the
# value of ctx from above captured by the closure because subsequent
# fragment runs will generally run in a new script run, thus we'll have a
# new ctx.
ctx = get_script_run_ctx(suppress_warning=True)
assert ctx is not None
if ctx.fragment_ids_this_run:
# This script run is a run of one or more fragments. We restore the
# state of ctx.cursors and dg_stack to the snapshots we took when this
# fragment was declared.
ctx.cursors = deepcopy(cursors_snapshot)
context_dg_stack.set(deepcopy(dg_stack_snapshot))
# Always add the fragment id to new_fragment_ids. For full app runs
# we need to add them anyways and for fragment runs we add them
# in case the to-be-executed fragment id was cleared from the storage
# by the full app run.
ctx.new_fragment_ids.add(fragment_id)
# Set ctx.current_fragment_id so that elements corresponding to this
# fragment get tagged with the appropriate ID. ctx.current_fragment_id gets
# reset after the fragment function finishes running to either return to the
# script (outside of any fragments) or to the outer fragment this one is
# nested in.
prev_fragment_id = ctx.current_fragment_id
ctx.current_fragment_id = fragment_id
try:
# Make sure we set the active script hash to the same value
# for the fragment run as when defined upon initialization
# This ensures that elements (especially widgets) are tied
# to a consistent active script hash
active_hash_context = (
ctx.run_with_active_hash(initialized_active_script_hash)
if initialized_active_script_hash != ctx.active_script_hash
else contextlib.nullcontext()
)
result = None
with active_hash_context:
with st.container():
try:
# use dg_stack instead of active_dg to have correct copy
# during execution (otherwise we can run into concurrency
# issues with multiple fragments). Use dg_stack because we
# just entered a container and [:-1] of the delta path
# because thats the prefix of the fragment,
# e.g. [0, 3, 0] -> [0, 3].
# All fragment elements start with [0, 3].
active_dg = context_dg_stack.get()[-1]
ctx.current_fragment_delta_path = (
active_dg._cursor.delta_path
if active_dg._cursor
else []
)[:-1]
result = non_optional_func(*args, **kwargs)
except (
RerunException,
StopException,
) as e:
# The wrapped_fragment function is executed
# inside of a exec_func_with_error_handling call, so
# there is a correct handler for these exceptions.
raise e
except Exception as e:
# render error here so that the delta path is correct
# for full app runs, the error will be displayed by the
# main code handler
# if not is_full_app_run:
handle_uncaught_app_exception(e)
# raise here again in case we are in full app execution
# and some flags have to be set
raise FragmentHandledException(e)
return result
finally:
ctx.current_fragment_id = prev_fragment_id
ctx.current_fragment_delta_path = []
ctx.fragment_storage.set(fragment_id, wrapped_fragment)
if run_every:
msg = ForwardMsg()
msg.auto_rerun.interval = time_to_seconds(run_every)
msg.auto_rerun.fragment_id = fragment_id
ctx.enqueue(msg)
# Immediate execute the wrapped fragment since we are in a full app run
return wrapped_fragment()
with contextlib.suppress(AttributeError):
# Make this a well-behaved decorator by preserving important function
# attributes.
wrap.__dict__.update(non_optional_func.__dict__)
wrap.__signature__ = inspect.signature(non_optional_func) # type: ignore
return wrap
@overload
def fragment(
func: F,
*,
run_every: int | float | timedelta | str | None = None,
) -> F: ...
# Support being able to pass parameters to this decorator (that is, being able to write
# `@fragment(run_every=5.0)`).
@overload
def fragment(
func: None = None,
*,
run_every: int | float | timedelta | str | None = None,
) -> Callable[[F], F]: ...
@gather_metrics("fragment")
def fragment(
func: F | None = None,
*,
run_every: int | float | timedelta | str | None = None,
) -> Callable[[F], F] | F:
"""Decorator to turn a function into a fragment which can rerun independently\
of the full app.
When a user interacts with an input widget created inside a fragment,
Streamlit only reruns the fragment instead of the full app. If
``run_every`` is set, Streamlit will also rerun the fragment at the
specified interval while the session is active, even if the user is not
interacting with your app.
To trigger an app rerun from inside a fragment, call ``st.rerun()``
directly. To trigger a fragment rerun from within itself, call
``st.rerun(scope="fragment")``. Any values from the fragment that need to
be accessed from the wider app should generally be stored in Session State.
When Streamlit element commands are called directly in a fragment, the
elements are cleared and redrawn on each fragment rerun, just like all
elements are redrawn on each app rerun. The rest of the app is persisted
during a fragment rerun. When a fragment renders elements into externally
created containers, the elements will not be cleared with each fragment
rerun. Instead, elements will accumulate in those containers with each
fragment rerun, until the next app rerun.
Calling ``st.sidebar`` in a fragment is not supported. To write elements to
the sidebar with a fragment, call your fragment function inside a
``with st.sidebar`` context manager.
Fragment code can interact with Session State, imported modules, and
other Streamlit elements created outside the fragment. Note that these
interactions are additive across multiple fragment reruns. You are
responsible for handling any side effects of that behavior.
.. warning::
- Fragments can only contain widgets in their main body. Fragments
can't render widgets to externally created containers.
Parameters
----------
func: callable
The function to turn into a fragment.
run_every: int, float, timedelta, str, or None
The time interval between automatic fragment reruns. This can be one of
the following:
- ``None`` (default).
- An ``int`` or ``float`` specifying the interval in seconds.
- A string specifying the time in a format supported by `Pandas'
Timedelta constructor <https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html>`_,
e.g. ``"1d"``, ``"1.5 days"``, or ``"1h23s"``.
- A ``timedelta`` object from `Python's built-in datetime library
<https://docs.python.org/3/library/datetime.html#timedelta-objects>`_,
e.g. ``timedelta(days=1)``.
If ``run_every`` is ``None``, the fragment will only rerun from
user-triggered events.
Examples
--------
The following example demonstrates basic usage of
``@st.fragment``. As an analogy, "inflating balloons" is a slow process that happens
outside of the fragment. "Releasing balloons" is a quick process that happens inside
of the fragment.
>>> import streamlit as st
>>> import time
>>>
>>> @st.fragment
>>> def release_the_balloons():
>>> st.button("Release the balloons", help="Fragment rerun")
>>> st.balloons()
>>>
>>> with st.spinner("Inflating balloons..."):
>>> time.sleep(5)
>>> release_the_balloons()
>>> st.button("Inflate more balloons", help="Full rerun")
.. output::
https://doc-fragment-balloons.streamlit.app/
height: 220px
This next example demonstrates how elements both inside and outside of a
fragement update with each app or fragment rerun. In this app, clicking
"Rerun full app" will increment both counters and update all values
displayed in the app. In contrast, clicking "Rerun fragment" will only
increment the counter within the fragment. In this case, the ``st.write``
command inside the fragment will update the app's frontend, but the two
``st.write`` commands outside the fragment will not update the frontend.
>>> import streamlit as st
>>>
>>> if "app_runs" not in st.session_state:
>>> st.session_state.app_runs = 0
>>> st.session_state.fragment_runs = 0
>>>
>>> @st.fragment
>>> def my_fragment():
>>> st.session_state.fragment_runs += 1
>>> st.button("Rerun fragment")
>>> st.write(f"Fragment says it ran {st.session_state.fragment_runs} times.")
>>>
>>> st.session_state.app_runs += 1
>>> my_fragment()
>>> st.button("Rerun full app")
>>> st.write(f"Full app says it ran {st.session_state.app_runs} times.")
>>> st.write(f"Full app sees that fragment ran {st.session_state.fragment_runs} times.")
.. output::
https://doc-fragment.streamlit.app/
height: 400px
You can also trigger an app rerun from inside a fragment by calling
``st.rerun``.
>>> import streamlit as st
>>>
>>> if "clicks" not in st.session_state:
>>> st.session_state.clicks = 0
>>>
>>> @st.fragment
>>> def count_to_five():
>>> if st.button("Plus one!"):
>>> st.session_state.clicks += 1
>>> if st.session_state.clicks % 5 == 0:
>>> st.rerun()
>>> return
>>>
>>> count_to_five()
>>> st.header(f"Multiples of five clicks: {st.session_state.clicks // 5}")
>>>
>>> if st.button("Check click count"):
>>> st.toast(f"## Total clicks: {st.session_state.clicks}")
.. output::
https://doc-fragment-rerun.streamlit.app/
height: 400px
"""
return _fragment(func, run_every=run_every)
@overload
def experimental_fragment(
func: F,
*,
run_every: int | float | timedelta | str | None = None,
) -> F: ...
# Support being able to pass parameters to this decorator (that is, being able to write
# `@fragment(run_every=5.0)`).
@overload
def experimental_fragment(
func: None = None,
*,
run_every: int | float | timedelta | str | None = None,
) -> Callable[[F], F]: ...
@gather_metrics("experimental_fragment")
def experimental_fragment(
func: F | None = None,
*,
run_every: int | float | timedelta | str | None = None,
) -> Callable[[F], F] | F:
"""Deprecated alias for @st.fragment. See the docstring for the decorator's new name."""
return _fragment(func, run_every=run_every, should_show_deprecation_warning=True)

View File

@@ -0,0 +1,234 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides global MediaFileManager object as `media_file_manager`."""
from __future__ import annotations
import collections
import threading
from typing import Final
from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorage
_LOGGER: Final = get_logger(__name__)
def _get_session_id() -> str:
"""Get the active AppSession's session_id."""
from streamlit.runtime.scriptrunner_utils.script_run_context import (
get_script_run_ctx,
)
ctx = get_script_run_ctx()
if ctx is None:
# This is only None when running "python myscript.py" rather than
# "streamlit run myscript.py". In which case the session ID doesn't
# matter and can just be a constant, as there's only ever "session".
return "dontcare"
else:
return ctx.session_id
class MediaFileMetadata:
"""Metadata that the MediaFileManager needs for each file it manages."""
def __init__(self, kind: MediaFileKind = MediaFileKind.MEDIA):
self._kind = kind
self._is_marked_for_delete = False
@property
def kind(self) -> MediaFileKind:
return self._kind
@property
def is_marked_for_delete(self) -> bool:
return self._is_marked_for_delete
def mark_for_delete(self) -> None:
self._is_marked_for_delete = True
class MediaFileManager:
"""In-memory file manager for MediaFile objects.
This keeps track of:
- Which files exist, and what their IDs are. This is important so we can
serve files by ID -- that's the whole point of this class!
- Which files are being used by which AppSession (by ID). This is
important so we can remove files from memory when no more sessions need
them.
- The exact location in the app where each file is being used (i.e. the
file's "coordinates"). This is is important so we can mark a file as "not
being used by a certain session" if it gets replaced by another file at
the same coordinates. For example, when doing an animation where the same
image is constantly replace with new frames. (This doesn't solve the case
where the file's coordinates keep changing for some reason, though! e.g.
if new elements keep being prepended to the app. Unlikely to happen, but
we should address it at some point.)
"""
def __init__(self, storage: MediaFileStorage):
self._storage = storage
# Dict of [file_id -> MediaFileMetadata]
self._file_metadata: dict[str, MediaFileMetadata] = {}
# Dict[session ID][coordinates] -> file_id.
self._files_by_session_and_coord: dict[str, dict[str, str]] = (
collections.defaultdict(dict)
)
# MediaFileManager is used from multiple threads, so all operations
# need to be protected with a Lock. (This is not an RLock, which
# means taking it multiple times from the same thread will deadlock.)
self._lock = threading.Lock()
def _get_inactive_file_ids(self) -> set[str]:
"""Compute the set of files that are stored in the manager, but are
not referenced by any active session. These are files that can be
safely deleted.
Thread safety: callers must hold `self._lock`.
"""
# Get the set of all our file IDs.
file_ids = set(self._file_metadata.keys())
# Subtract all IDs that are in use by each session
for session_file_ids_by_coord in self._files_by_session_and_coord.values():
file_ids.difference_update(session_file_ids_by_coord.values())
return file_ids
def remove_orphaned_files(self) -> None:
"""Remove all files that are no longer referenced by any active session.
Safe to call from any thread.
"""
_LOGGER.debug("Removing orphaned files...")
with self._lock:
for file_id in self._get_inactive_file_ids():
file = self._file_metadata[file_id]
if file.kind == MediaFileKind.MEDIA:
self._delete_file(file_id)
elif file.kind == MediaFileKind.DOWNLOADABLE:
if file.is_marked_for_delete:
self._delete_file(file_id)
else:
file.mark_for_delete()
def _delete_file(self, file_id: str) -> None:
"""Delete the given file from storage, and remove its metadata from
self._files_by_id.
Thread safety: callers must hold `self._lock`.
"""
_LOGGER.debug("Deleting File: %s", file_id)
self._storage.delete_file(file_id)
del self._file_metadata[file_id]
def clear_session_refs(self, session_id: str | None = None) -> None:
"""Remove the given session's file references.
(This does not remove any files from the manager - you must call
`remove_orphaned_files` for that.)
Should be called whenever ScriptRunner completes and when a session ends.
Safe to call from any thread.
"""
if session_id is None:
session_id = _get_session_id()
_LOGGER.debug("Disconnecting files for session with ID %s", session_id)
with self._lock:
if session_id in self._files_by_session_and_coord:
del self._files_by_session_and_coord[session_id]
_LOGGER.debug(
"Sessions still active: %r", self._files_by_session_and_coord.keys()
)
_LOGGER.debug(
"Files: %s; Sessions with files: %s",
len(self._file_metadata),
len(self._files_by_session_and_coord),
)
def add(
self,
path_or_data: bytes | str,
mimetype: str,
coordinates: str,
file_name: str | None = None,
is_for_static_download: bool = False,
) -> str:
"""Add a new MediaFile with the given parameters and return its URL.
If an identical file already exists, return the existing URL
and registers the current session as a user.
Safe to call from any thread.
Parameters
----------
path_or_data : bytes or str
If bytes: the media file's raw data. If str: the name of a file
to load from disk.
mimetype : str
The mime type for the file. E.g. "audio/mpeg".
This string will be used in the "Content-Type" header when the file
is served over HTTP.
coordinates : str
Unique string identifying an element's location.
Prevents memory leak of "forgotten" file IDs when element media
is being replaced-in-place (e.g. an st.image stream).
coordinates should be of the form: "1.(3.-14).5"
file_name : str or None
Optional file_name. Used to set the filename in the response header.
is_for_static_download: bool
Indicate that data stored for downloading as a file,
not as a media for rendering at page. [default: False]
Returns
-------
str
The url that the frontend can use to fetch the media.
Raises
------
If a filename is passed, any Exception raised when trying to read the
file will be re-raised.
"""
session_id = _get_session_id()
with self._lock:
kind = (
MediaFileKind.DOWNLOADABLE
if is_for_static_download
else MediaFileKind.MEDIA
)
file_id = self._storage.load_and_get_id(
path_or_data, mimetype, kind, file_name
)
metadata = MediaFileMetadata(kind=kind)
self._file_metadata[file_id] = metadata
self._files_by_session_and_coord[session_id][coordinates] = file_id
return self._storage.get_url(file_id)

View File

@@ -0,0 +1,143 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from abc import abstractmethod
from enum import Enum
from typing import Protocol
class MediaFileKind(Enum):
# st.image, st.video, st.audio files
MEDIA = "media"
# st.download_button files
DOWNLOADABLE = "downloadable"
class MediaFileStorageError(Exception):
"""Exception class for errors raised by MediaFileStorage.
When running in "development mode", the full text of these errors
is displayed in the frontend, so errors should be human-readable
(and actionable).
When running in "release mode", errors are redacted on the
frontend; we instead show a generic "Something went wrong!" message.
"""
class MediaFileStorage(Protocol):
@abstractmethod
def load_and_get_id(
self,
path_or_data: str | bytes,
mimetype: str,
kind: MediaFileKind,
filename: str | None = None,
) -> str:
"""Load the given file path or bytes into the manager and return
an ID that uniquely identifies it.
It's an error to pass a URL to this function. (Media stored at
external URLs can be served directly to the Streamlit frontend;
there's no need to store this data in MediaFileStorage.)
Parameters
----------
path_or_data
A path to a file, or the file's raw data as bytes.
mimetype
The media's mimetype. Used to set the Content-Type header when
serving the media over HTTP.
kind
The kind of file this is: either MEDIA, or DOWNLOADABLE.
filename : str or None
Optional filename. Used to set the filename in the response header.
Returns
-------
str
The unique ID of the media file.
Raises
------
MediaFileStorageError
Raised if the media can't be loaded (for example, if a file
path is invalid).
"""
raise NotImplementedError
@abstractmethod
def get_url(self, file_id: str) -> str:
"""Return a URL for a file in the manager.
Parameters
----------
file_id
The file's ID, returned from load_media_and_get_id().
Returns
-------
str
A URL that the frontend can load the file from. Because this
URL may expire, it should not be cached!
Raises
------
MediaFileStorageError
Raised if the manager doesn't contain an object with the given ID.
"""
raise NotImplementedError
@abstractmethod
def delete_file(self, file_id: str) -> None:
"""Delete a file from the manager.
This should be called when a given file is no longer referenced
by any connected client, so that the MediaFileStorage can free its
resources.
Calling `delete_file` on a file_id that doesn't exist is allowed,
and is a no-op. (This means that multiple `delete_file` calls with
the same file_id is not an error.)
Note: implementations can choose to ignore `delete_file` calls -
this function is a *suggestion*, not a *command*. Callers should
not rely on file deletion happening immediately (or at all).
Parameters
----------
file_id
The file's ID, returned from load_media_and_get_id().
Returns
-------
None
Raises
------
MediaFileStorageError
Raised if file deletion fails for any reason. Note that these
failures will generally not be shown on the frontend (file
deletion usually occurs on session disconnect).
"""
raise NotImplementedError

View File

@@ -0,0 +1,181 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaFileStorage implementation that stores files in memory."""
from __future__ import annotations
import contextlib
import hashlib
import mimetypes
import os.path
from typing import Final, NamedTuple
from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import (
MediaFileKind,
MediaFileStorage,
MediaFileStorageError,
)
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
_LOGGER: Final = get_logger(__name__)
# Mimetype -> filename extension map for the `get_extension_for_mimetype`
# function. We use Python's `mimetypes.guess_extension` for most mimetypes,
# but (as of Python 3.9) `mimetypes.guess_extension("audio/wav")` returns None,
# so we handle it ourselves.
PREFERRED_MIMETYPE_EXTENSION_MAP: Final = {
"audio/wav": ".wav",
"text/vtt": ".vtt",
}
def _calculate_file_id(data: bytes, mimetype: str, filename: str | None = None) -> str:
"""Hash data, mimetype, and an optional filename to generate a stable file ID.
Parameters
----------
data
Content of in-memory file in bytes. Other types will throw TypeError.
mimetype
Any string. Will be converted to bytes and used to compute a hash.
filename
Any string. Will be converted to bytes and used to compute a hash.
"""
filehash = hashlib.new("sha224", usedforsecurity=False)
filehash.update(data)
filehash.update(bytes(mimetype.encode()))
if filename is not None:
filehash.update(bytes(filename.encode()))
return filehash.hexdigest()
def get_extension_for_mimetype(mimetype: str) -> str:
if mimetype in PREFERRED_MIMETYPE_EXTENSION_MAP:
return PREFERRED_MIMETYPE_EXTENSION_MAP[mimetype]
extension = mimetypes.guess_extension(mimetype, strict=False)
if extension is None:
return ""
return extension
class MemoryFile(NamedTuple):
"""A MediaFile stored in memory."""
content: bytes
mimetype: str
kind: MediaFileKind
filename: str | None
@property
def content_size(self) -> int:
return len(self.content)
class MemoryMediaFileStorage(MediaFileStorage, CacheStatsProvider):
def __init__(self, media_endpoint: str):
"""Create a new MemoryMediaFileStorage instance.
Parameters
----------
media_endpoint
The name of the local endpoint that media is served from.
This endpoint should start with a forward-slash (e.g. "/media").
"""
self._files_by_id: dict[str, MemoryFile] = {}
self._media_endpoint = media_endpoint
def load_and_get_id(
self,
path_or_data: str | bytes,
mimetype: str,
kind: MediaFileKind,
filename: str | None = None,
) -> str:
"""Add a file to the manager and return its ID."""
file_data: bytes
if isinstance(path_or_data, str):
file_data = self._read_file(path_or_data)
else:
file_data = path_or_data
# Because our file_ids are stable, if we already have a file with the
# given ID, we don't need to create a new one.
file_id = _calculate_file_id(file_data, mimetype, filename)
if file_id not in self._files_by_id:
_LOGGER.debug("Adding media file %s", file_id)
media_file = MemoryFile(
content=file_data, mimetype=mimetype, kind=kind, filename=filename
)
self._files_by_id[file_id] = media_file
return file_id
def get_file(self, filename: str) -> MemoryFile:
"""Return the MemoryFile with the given filename. Filenames are of the
form "file_id.extension". (Note that this is *not* the optional
user-specified filename for download files.).
Raises a MediaFileStorageError if no such file exists.
"""
file_id = os.path.splitext(filename)[0]
try:
return self._files_by_id[file_id]
except KeyError as e:
raise MediaFileStorageError(
f"Bad filename '{filename}'. (No media file with id '{file_id}')"
) from e
def get_url(self, file_id: str) -> str:
"""Get a URL for a given media file. Raise a MediaFileStorageError if
no such file exists.
"""
media_file = self.get_file(file_id)
extension = get_extension_for_mimetype(media_file.mimetype)
return f"{self._media_endpoint}/{file_id}{extension}"
def delete_file(self, file_id: str) -> None:
"""Delete the file with the given ID."""
# We swallow KeyErrors here - it's not an error to delete a file
# that doesn't exist.
with contextlib.suppress(KeyError):
del self._files_by_id[file_id]
def _read_file(self, filename: str) -> bytes:
"""Read a file into memory. Raise MediaFileStorageError if we can't."""
try:
with open(filename, "rb") as f:
return f.read()
except Exception as ex:
raise MediaFileStorageError(f"Error opening '{filename}'") from ex
def get_stats(self) -> list[CacheStat]:
# We operate on a copy of our dict, to avoid race conditions
# with other threads that may be manipulating the cache.
files_by_id = self._files_by_id.copy()
stats: list[CacheStat] = [
CacheStat(
category_name="st_memory_media_file_storage",
cache_name="",
byte_length=len(file.content),
)
for _, file in files_by_id.items()
]
return group_stats(stats)

View File

@@ -0,0 +1,77 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from cachetools import TTLCache
from streamlit.runtime.session_manager import SessionInfo, SessionStorage
if TYPE_CHECKING:
from collections.abc import MutableMapping
class MemorySessionStorage(SessionStorage):
"""A SessionStorage that stores sessions in memory.
At most maxsize sessions are stored with a TTL of ttl seconds. This class is really
just a thin wrapper around cachetools.TTLCache that complies with the SessionStorage
protocol.
"""
# NOTE: The defaults for maxsize and ttl are chosen arbitrarily for now. These
# numbers are reasonable as the main problems we're trying to solve at the moment are
# caused by transient disconnects that are usually just short network blips. In the
# future, we may want to increase both to support use cases such as saving state for
# much longer periods of time. For example, we may want session state to persist if
# a user closes their laptop lid and comes back to an app hours later.
def __init__(
self,
maxsize: int = 128,
ttl_seconds: int = 2 * 60, # 2 minutes
) -> None:
"""Instantiate a new MemorySessionStorage.
Parameters
----------
maxsize
The maximum number of sessions we allow to be stored in this
MemorySessionStorage. If an entry needs to be removed because we have
exceeded this number, either
- an expired entry is removed, or
- the least recently used entry is removed (if no entries have expired).
ttl_seconds
The time in seconds for an entry added to a MemorySessionStorage to live.
After this amount of time has passed for a given entry, it becomes
inaccessible and will be removed eventually.
"""
self._cache: MutableMapping[str, SessionInfo] = TTLCache(
maxsize=maxsize, ttl=ttl_seconds
)
def get(self, session_id: str) -> SessionInfo | None:
return self._cache.get(session_id, None)
def save(self, session_info: SessionInfo) -> None:
self._cache[session_info.session.id] = session_info
def delete(self, session_id: str) -> None:
del self._cache[session_id]
def list(self) -> list[SessionInfo]:
return list(self._cache.values())

View File

@@ -0,0 +1,138 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import uuid
from collections import defaultdict
from typing import TYPE_CHECKING
from streamlit import util
from streamlit.runtime.stats import CacheStat, group_stats
from streamlit.runtime.uploaded_file_manager import (
UploadedFileManager,
UploadedFileRec,
UploadFileUrlInfo,
)
if TYPE_CHECKING:
from collections.abc import Sequence
class MemoryUploadedFileManager(UploadedFileManager):
"""Holds files uploaded by users of the running Streamlit app.
This class can be used safely from multiple threads simultaneously.
"""
def __init__(self, upload_endpoint: str):
self.file_storage: dict[str, dict[str, UploadedFileRec]] = defaultdict(dict)
self.endpoint = upload_endpoint
def get_files(
self, session_id: str, file_ids: Sequence[str]
) -> list[UploadedFileRec]:
"""Return a list of UploadedFileRec for a given sequence of file_ids.
Parameters
----------
session_id
The ID of the session that owns the files.
file_ids
The sequence of ids associated with files to retrieve.
Returns
-------
List[UploadedFileRec]
A list of URL UploadedFileRec instances, each instance contains information
about uploaded file.
"""
session_storage = self.file_storage[session_id]
file_recs = []
for file_id in file_ids:
file_rec = session_storage.get(file_id, None)
if file_rec is not None:
file_recs.append(file_rec)
return file_recs
def remove_session_files(self, session_id: str) -> None:
"""Remove all files associated with a given session."""
self.file_storage.pop(session_id, None)
def __repr__(self) -> str:
return util.repr_(self)
def add_file(
self,
session_id: str,
file: UploadedFileRec,
) -> None:
"""
Safe to call from any thread.
Parameters
----------
session_id
The ID of the session that owns the file.
file
The file to add.
"""
self.file_storage[session_id][file.file_id] = file
def remove_file(self, session_id, file_id):
"""Remove file with given file_id associated with a given session."""
session_storage = self.file_storage[session_id]
session_storage.pop(file_id, None)
def get_upload_urls(
self, session_id: str, file_names: Sequence[str]
) -> list[UploadFileUrlInfo]:
"""Return a list of UploadFileUrlInfo for a given sequence of file_names."""
result = []
for _ in file_names:
file_id = str(uuid.uuid4())
result.append(
UploadFileUrlInfo(
file_id=file_id,
upload_url=f"{self.endpoint}/{session_id}/{file_id}",
delete_url=f"{self.endpoint}/{session_id}/{file_id}",
)
)
return result
def get_stats(self) -> list[CacheStat]:
"""Return the manager's CacheStats.
Safe to call from any thread.
"""
# Flatten all files into a single list
all_files: list[UploadedFileRec] = []
# Make copy of self.file_storage for thread safety, to be sure
# that main storage won't be changed form other thread
file_storage_copy = self.file_storage.copy()
for session_storage in file_storage_copy.values():
all_files.extend(session_storage.values())
stats: list[CacheStat] = [
CacheStat(
category_name="UploadedFileManager",
cache_name="",
byte_length=len(file.data),
)
for file in all_files
]
return group_stats(stats)

View File

@@ -0,0 +1,486 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import inspect
import os
import sys
import threading
import time
import uuid
from collections.abc import Sized
from functools import wraps
from typing import Any, Callable, Final, TypeVar, cast, overload
from streamlit import config, util
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.PageProfile_pb2 import Argument, Command
from streamlit.runtime.scriptrunner_utils.exceptions import RerunException
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
_LOGGER: Final = get_logger(__name__)
# Limit the number of commands to keep the page profile message small
_MAX_TRACKED_COMMANDS: Final = 200
# Only track a maximum of 25 uses per unique command since some apps use
# commands excessively (e.g. calling add_rows thousands of times in one rerun)
# making the page profile useless.
_MAX_TRACKED_PER_COMMAND: Final = 25
# A mapping to convert from the actual name to preferred/shorter representations
_OBJECT_NAME_MAPPING: Final = {
"streamlit.delta_generator.DeltaGenerator": "DG",
"pandas.core.frame.DataFrame": "DataFrame",
"plotly.graph_objs._figure.Figure": "PlotlyFigure",
"bokeh.plotting.figure.Figure": "BokehFigure",
"matplotlib.figure.Figure": "MatplotlibFigure",
"pandas.io.formats.style.Styler": "PandasStyler",
"pandas.core.indexes.base.Index": "PandasIndex",
"pandas.core.series.Series": "PandasSeries",
"streamlit.connections.snowpark_connection.SnowparkConnection": "SnowparkConnection",
"streamlit.connections.sql_connection.SQLConnection": "SQLConnection",
}
# A list of dependencies to check for attribution
_ATTRIBUTIONS_TO_CHECK: Final = [
# DB Clients:
"pymysql",
"MySQLdb",
"mysql",
"pymongo",
"ibis",
"boto3",
"psycopg2",
"psycopg3",
"sqlalchemy",
"elasticsearch",
"pyodbc",
"pymssql",
"cassandra",
"azure",
"redis",
"sqlite3",
"neo4j",
"duckdb",
"opensearchpy",
"supabase",
# Dataframe Libraries:
"polars",
"dask",
"vaex",
"modin",
"pyspark",
"cudf",
"xarray",
"ray",
"geopandas",
"mars",
"tables",
"zarr",
"datasets",
# ML & LLM Tools:
"mistralai",
"openai",
"langchain",
"llama_index",
"llama_cpp",
"anthropic",
"pyllamacpp",
"cohere",
"transformers",
"nomic",
"diffusers",
"semantic_kernel",
"replicate",
"huggingface_hub",
"wandb",
"torch",
"tensorflow",
"trubrics",
"comet_ml",
"clarifai",
"reka",
"hegel",
"fastchat",
"assemblyai",
"openllm",
"embedchain",
"haystack",
"vllm",
"alpa",
"jinaai",
"guidance",
"litellm",
"comet_llm",
"instructor",
"xgboost",
"lightgbm",
"catboost",
"sklearn",
# Workflow Tools:
"prefect",
"luigi",
"airflow",
"dagster",
# Vector Stores:
"pgvector",
"faiss",
"annoy",
"pinecone",
"chromadb",
"weaviate",
"qdrant_client",
"pymilvus",
"lancedb",
# Others:
"snowflake",
"streamlit_extras",
"streamlit_pydantic",
"pydantic",
"plost",
"authlib",
]
_ETC_MACHINE_ID_PATH = "/etc/machine-id"
_DBUS_MACHINE_ID_PATH = "/var/lib/dbus/machine-id"
def _get_machine_id_v3() -> str:
"""Get the machine ID.
This is a unique identifier for a user for tracking metrics,
that is broken in different ways in some Linux distros and Docker images.
- at times just a hash of '', which means many machines map to the same ID
- at times a hash of the same string, when running in a Docker container
"""
machine_id = str(uuid.getnode())
if os.path.isfile(_ETC_MACHINE_ID_PATH):
with open(_ETC_MACHINE_ID_PATH) as f:
machine_id = f.read()
elif os.path.isfile(_DBUS_MACHINE_ID_PATH):
with open(_DBUS_MACHINE_ID_PATH) as f:
machine_id = f.read()
return machine_id
class Installation:
_instance_lock = threading.Lock()
_instance: Installation | None = None
@classmethod
def instance(cls) -> Installation:
"""Returns the singleton Installation."""
# We use a double-checked locking optimization to avoid the overhead
# of acquiring the lock in the common case:
# https://en.wikipedia.org/wiki/Double-checked_locking
if cls._instance is None:
with cls._instance_lock:
if cls._instance is None:
cls._instance = Installation()
return cls._instance
def __init__(self):
self.installation_id_v3 = str(
uuid.uuid5(uuid.NAMESPACE_DNS, _get_machine_id_v3())
)
def __repr__(self) -> str:
return util.repr_(self)
@property
def installation_id(self):
return self.installation_id_v3
def _get_type_name(obj: object) -> str:
"""Get a simplified name for the type of the given object."""
with contextlib.suppress(Exception):
obj_type = obj if inspect.isclass(obj) else type(obj)
type_name = "unknown"
if hasattr(obj_type, "__qualname__"):
type_name = obj_type.__qualname__
elif hasattr(obj_type, "__name__"):
type_name = obj_type.__name__
if obj_type.__module__ != "builtins":
# Add the full module path
type_name = f"{obj_type.__module__}.{type_name}"
if type_name in _OBJECT_NAME_MAPPING:
type_name = _OBJECT_NAME_MAPPING[type_name]
return type_name
return "failed"
def _get_top_level_module(func: Callable[..., Any]) -> str:
"""Get the top level module for the given function."""
module = inspect.getmodule(func)
if module is None or not module.__name__:
return "unknown"
return module.__name__.split(".")[0]
def _get_arg_metadata(arg: object) -> str | None:
"""Get metadata information related to the value of the given object."""
with contextlib.suppress(Exception):
if isinstance(arg, (bool)):
return f"val:{arg}"
if isinstance(arg, Sized):
return f"len:{len(arg)}"
return None
def _get_command_telemetry(
_command_func: Callable[..., Any], _command_name: str, *args, **kwargs
) -> Command:
"""Get telemetry information for the given callable and its arguments."""
arg_keywords = inspect.getfullargspec(_command_func).args
self_arg: Any | None = None
arguments: list[Argument] = []
is_method = inspect.ismethod(_command_func)
name = _command_name
for i, arg in enumerate(args):
pos = i
if is_method:
# If func is a method, ignore the first argument (self)
i = i + 1
keyword = arg_keywords[i] if len(arg_keywords) > i else f"{i}"
if keyword == "self":
self_arg = arg
continue
argument = Argument(k=keyword, t=_get_type_name(arg), p=pos)
arg_metadata = _get_arg_metadata(arg)
if arg_metadata:
argument.m = arg_metadata
arguments.append(argument)
for kwarg, kwarg_value in kwargs.items():
argument = Argument(k=kwarg, t=_get_type_name(kwarg_value))
arg_metadata = _get_arg_metadata(kwarg_value)
if arg_metadata:
argument.m = arg_metadata
arguments.append(argument)
top_level_module = _get_top_level_module(_command_func)
if top_level_module != "streamlit":
# If the gather_metrics decorator is used outside of streamlit library
# we enforce a prefix to be added to the tracked command:
name = f"external:{top_level_module}:{name}"
if (
name == "create_instance"
and self_arg
and hasattr(self_arg, "name")
and self_arg.name
):
name = f"component:{self_arg.name}"
return Command(name=name, args=arguments)
def to_microseconds(seconds: float) -> int:
"""Convert seconds into microseconds."""
return int(seconds * 1_000_000)
F = TypeVar("F", bound=Callable[..., Any])
@overload
def gather_metrics(
name: str,
func: F,
) -> F: ...
@overload
def gather_metrics(
name: str,
func: None = None,
) -> Callable[[F], F]: ...
def gather_metrics(name: str, func: F | None = None) -> Callable[[F], F] | F:
"""Function decorator to add telemetry tracking to commands.
Parameters
----------
func : callable
The function to track for telemetry.
name : str or None
Overwrite the function name with a custom name that is used for telemetry tracking.
Example
-------
>>> @st.gather_metrics
... def my_command(url):
... return url
>>> @st.gather_metrics(name="custom_name")
... def my_command(url):
... return url
"""
if not name:
_LOGGER.warning("gather_metrics: name is empty")
name = "undefined"
if func is None:
# Support passing the params via function decorator
def wrapper(f: F) -> F:
return gather_metrics(
name=name,
func=f,
)
return wrapper
else:
# To make mypy type narrow F | None -> F
non_optional_func = func
@wraps(non_optional_func)
def wrapped_func(*args, **kwargs):
from timeit import default_timer as timer
exec_start = timer()
ctx = get_script_run_ctx(suppress_warning=True)
tracking_activated = (
ctx is not None
and ctx.gather_usage_stats
and not ctx.command_tracking_deactivated
and len(ctx.tracked_commands)
< _MAX_TRACKED_COMMANDS # Prevent too much memory usage
)
command_telemetry: Command | None = None
# This flag is needed to make sure that only the command (the outermost command)
# that deactivated tracking (via ctx.command_tracking_deactivated) is able to reset it
# again. This is important to prevent nested commands from reactivating tracking.
# At this point, we don't know yet if the command will deactivated tracking.
has_set_command_tracking_deactivated = False
if ctx and tracking_activated:
try:
command_telemetry = _get_command_telemetry(
non_optional_func, name, *args, **kwargs
)
if (
command_telemetry.name not in ctx.tracked_commands_counter
or ctx.tracked_commands_counter[command_telemetry.name]
< _MAX_TRACKED_PER_COMMAND
):
ctx.tracked_commands.append(command_telemetry)
ctx.tracked_commands_counter.update([command_telemetry.name])
# Deactivate tracking to prevent calls inside already tracked commands
ctx.command_tracking_deactivated = True
# The ctx.command_tracking_deactivated flag was set to True,
# we also need to set has_set_command_tracking_deactivated to True
# to make sure that this command is able to reset it again.
has_set_command_tracking_deactivated = True
except Exception as ex:
# Always capture all exceptions since we want to make sure that
# the telemetry never causes any issues.
_LOGGER.debug("Failed to collect command telemetry", exc_info=ex)
try:
result = non_optional_func(*args, **kwargs)
except RerunException as ex:
# Duplicated from below, because static analysis tools get confused
# by deferring the rethrow.
if tracking_activated and command_telemetry:
command_telemetry.time = to_microseconds(timer() - exec_start)
raise ex
finally:
# Activate tracking again if command executes without any exceptions
# we only want to do that if this command has set the
# flag to deactivate tracking.
if ctx and has_set_command_tracking_deactivated:
ctx.command_tracking_deactivated = False
if tracking_activated and command_telemetry:
# Set the execution time to the measured value
command_telemetry.time = to_microseconds(timer() - exec_start)
return result
with contextlib.suppress(AttributeError):
# Make this a well-behaved decorator by preserving important function
# attributes.
wrapped_func.__dict__.update(non_optional_func.__dict__)
wrapped_func.__signature__ = inspect.signature(non_optional_func) # type: ignore
return cast("F", wrapped_func)
def create_page_profile_message(
commands: list[Command],
exec_time: int,
prep_time: int,
uncaught_exception: str | None = None,
) -> ForwardMsg:
"""Create and return the full PageProfile ForwardMsg."""
msg = ForwardMsg()
page_profile = msg.page_profile
page_profile.commands.extend(commands)
page_profile.exec_time = exec_time
page_profile.prep_time = prep_time
page_profile.headless = config.get_option("server.headless")
# Collect all config options that have been manually set
config_options: set[str] = set()
if config._config_options:
for option_name in config._config_options.keys():
if not config.is_manually_set(option_name):
# We only care about manually defined options
continue
config_option = config._config_options[option_name]
if config_option.is_default:
option_name = f"{option_name}:default"
config_options.add(option_name)
page_profile.config.extend(config_options)
# Check the predefined set of modules for attribution
attributions: set[str] = {
attribution
for attribution in _ATTRIBUTIONS_TO_CHECK
if attribution in sys.modules
}
page_profile.os = str(sys.platform)
page_profile.timezone = str(time.tzname)
page_profile.attributions.extend(attributions)
if uncaught_exception:
page_profile.uncaught_exception = uncaught_exception
if ctx := get_script_run_ctx():
page_profile.is_fragment_run = bool(ctx.fragment_ids_this_run)
return msg

View File

@@ -0,0 +1,162 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any
from streamlit.util import calc_md5
if TYPE_CHECKING:
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
from streamlit.source_util import PageHash, PageInfo, PageName, ScriptPath
class PagesManager:
"""
PagesManager is responsible for managing the set of pages that make up
the entire application. At the start we assume the main script is the
only page. As the script runs, the main script can call `st.navigation`
to set the set of pages that make up the app.
"""
uses_pages_directory: bool | None = None
def __init__(
self,
main_script_path: ScriptPath,
script_cache: ScriptCache | None = None,
**kwargs,
):
self._main_script_path = main_script_path
self._main_script_hash: PageHash = calc_md5(main_script_path)
self._script_cache = script_cache
self._intended_page_script_hash: PageHash | None = None
self._intended_page_name: PageName | None = None
self._current_page_script_hash: PageHash = ""
self._pages: dict[PageHash, PageInfo] | None = None
# A relic of v1 of Multipage apps, we performed special handling
# for apps with a pages directory. We will keep this flag around
# for now to maintain the behavior for apps that were created with
# the pages directory feature.
#
# NOTE: we will update the feature if the flag has not been set
# this means that if users use v2 behavior, the flag will
# always be set to False
if PagesManager.uses_pages_directory is None:
PagesManager.uses_pages_directory = Path(
self.main_script_parent / "pages"
).exists()
@property
def main_script_path(self) -> ScriptPath:
return self._main_script_path
@property
def main_script_parent(self) -> Path:
return Path(self._main_script_path).parent
@property
def main_script_hash(self) -> PageHash:
return self._main_script_hash
@property
def current_page_script_hash(self) -> PageHash:
return self._current_page_script_hash
@property
def intended_page_name(self) -> PageName | None:
return self._intended_page_name
@property
def intended_page_script_hash(self) -> PageHash | None:
return self._intended_page_script_hash
def set_current_page_script_hash(self, page_script_hash: PageHash) -> None:
self._current_page_script_hash = page_script_hash
def get_main_page(self) -> PageInfo:
return {
"script_path": self._main_script_path,
"page_script_hash": self._main_script_hash,
}
def set_script_intent(
self, page_script_hash: PageHash, page_name: PageName
) -> None:
self._intended_page_script_hash = page_script_hash
self._intended_page_name = page_name
def get_initial_active_script(
self, page_script_hash: PageHash, page_name: PageName
) -> PageInfo | None:
return {
# We always run the main script in V2 as it's the common code
"script_path": self.main_script_path,
"page_script_hash": page_script_hash
or self.main_script_hash, # Default Hash
}
def get_pages(self) -> dict[PageHash, PageInfo]:
# If pages are not set, provide the common page info where
# - the main script path is the executing script to start
# - the page script hash and name reflects the intended page requested
return self._pages or {
self.main_script_hash: {
"page_script_hash": self.intended_page_script_hash or "",
"page_name": self.intended_page_name or "",
"icon": "",
"script_path": self.main_script_path,
}
}
def set_pages(self, pages: dict[PageHash, PageInfo]) -> None:
self._pages = pages
def get_page_script(self, fallback_page_hash: PageHash = "") -> PageInfo | None:
if self._pages is None:
return None
if self.intended_page_script_hash:
# We assume that if initial page hash is specified, that a page should
# exist, so we check out the page script hash or the default page hash
# as a backup
return self._pages.get(
self.intended_page_script_hash,
self._pages.get(fallback_page_hash, None),
)
elif self.intended_page_name:
# If a user navigates directly to a non-main page of an app, the
# the page name can identify the page script to run
return next(
filter(
# There seems to be this weird bug with mypy where it
# thinks that p can be None (which is impossible given the
# types of pages), so we add `p and` at the beginning of
# the predicate to circumvent this.
lambda p: p and (p["url_pathname"] == self.intended_page_name),
self._pages.values(),
),
None,
)
return self._pages.get(fallback_page_hash, None)
def get_page_script_byte_code(self, script_path: str) -> Any:
if self._script_cache is None:
# Returning an empty string for an empty script
return ""
return self._script_cache.get_bytecode(script_path)

View File

@@ -0,0 +1,792 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import time
import traceback
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Final, NamedTuple
from streamlit import config
from streamlit.components.lib.local_component_registry import LocalComponentRegistry
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.caching import (
get_data_cache_stats_provider,
get_resource_cache_stats_provider,
)
from streamlit.runtime.caching.storage.local_disk_cache_storage import (
LocalDiskCacheStorageManager,
)
from streamlit.runtime.forward_msg_cache import (
ForwardMsgCache,
create_reference_msg,
populate_hash_if_needed,
)
from streamlit.runtime.media_file_manager import MediaFileManager
from streamlit.runtime.memory_session_storage import MemorySessionStorage
from streamlit.runtime.runtime_util import is_cacheable_msg
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
from streamlit.runtime.session_manager import (
ActiveSessionInfo,
SessionClient,
SessionClientDisconnectedError,
SessionManager,
SessionStorage,
)
from streamlit.runtime.state import (
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
SessionStateStatProvider,
)
from streamlit.runtime.stats import StatsManager
from streamlit.runtime.websocket_session_manager import WebsocketSessionManager
if TYPE_CHECKING:
from collections.abc import Awaitable
from streamlit.components.types.base_component_registry import BaseComponentRegistry
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.runtime.caching.storage import CacheStorageManager
from streamlit.runtime.media_file_storage import MediaFileStorage
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
# Wait for the script run result for 60s and if no result is available give up
SCRIPT_RUN_CHECK_TIMEOUT: Final = 60
_LOGGER: Final = get_logger(__name__)
class RuntimeStoppedError(Exception):
"""Raised by operations on a Runtime instance that is stopped."""
@dataclass(frozen=True)
class RuntimeConfig:
"""Config options for StreamlitRuntime."""
# The filesystem path of the Streamlit script to run.
script_path: str
# DEPRECATED: We need to keep this field around for compatibility reasons, but we no
# longer use this anywhere.
command_line: str | None
# The storage backend for Streamlit's MediaFileManager.
media_file_storage: MediaFileStorage
# The upload file manager
uploaded_file_manager: UploadedFileManager
# The cache storage backend for Streamlit's st.cache_data.
cache_storage_manager: CacheStorageManager = field(
default_factory=LocalDiskCacheStorageManager
)
# The ComponentRegistry instance to use.
component_registry: BaseComponentRegistry = field(
default_factory=LocalComponentRegistry
)
# The SessionManager class to be used.
session_manager_class: type[SessionManager] = WebsocketSessionManager
# The SessionStorage instance for the SessionManager to use.
session_storage: SessionStorage = field(default_factory=MemorySessionStorage)
# True if the command used to start Streamlit was `streamlit hello`.
is_hello: bool = False
# TODO(vdonato): Eventually add a new fragment_storage_class field enabling the code
# creating a new Streamlit Runtime to configure the FragmentStorage instances
# created by each new AppSession. We choose not to do this for now to avoid adding
# additional complexity to RuntimeConfig/SessionManager/etc when it's unlikely
# we'll have a custom implementation of this class anytime soon.
class RuntimeState(Enum):
INITIAL = "INITIAL"
NO_SESSIONS_CONNECTED = "NO_SESSIONS_CONNECTED"
ONE_OR_MORE_SESSIONS_CONNECTED = "ONE_OR_MORE_SESSIONS_CONNECTED"
STOPPING = "STOPPING"
STOPPED = "STOPPED"
class AsyncObjects(NamedTuple):
"""Container for all asyncio objects that Runtime manages.
These cannot be initialized until the Runtime's eventloop is assigned.
"""
# The eventloop that Runtime is running on.
eventloop: asyncio.AbstractEventLoop
# Set after Runtime.stop() is called. Never cleared.
must_stop: asyncio.Event
# Set when a client connects; cleared when we have no connected clients.
has_connection: asyncio.Event
# Set after a ForwardMsg is enqueued; cleared when we flush ForwardMsgs.
need_send_data: asyncio.Event
# Completed when the Runtime has started.
started: asyncio.Future[None]
# Completed when the Runtime has stopped.
stopped: asyncio.Future[None]
class Runtime:
_instance: Runtime | None = None
@classmethod
def instance(cls) -> Runtime:
"""Return the singleton Runtime instance. Raise an Error if the
Runtime hasn't been created yet.
"""
if cls._instance is None:
raise RuntimeError("Runtime hasn't been created!")
return cls._instance
@classmethod
def exists(cls) -> bool:
"""True if the singleton Runtime instance has been created.
When a Streamlit app is running in "raw mode" - that is, when the
app is run via `python app.py` instead of `streamlit run app.py` -
the Runtime will not exist, and various Streamlit functions need
to adapt.
"""
return cls._instance is not None
def __init__(self, config: RuntimeConfig):
"""Create a Runtime instance. It won't be started yet.
Runtime is *not* thread-safe. Its public methods are generally
safe to call only on the same thread that its event loop runs on.
Parameters
----------
config
Config options.
"""
if Runtime._instance is not None:
raise RuntimeError("Runtime instance already exists!")
Runtime._instance = self
# Will be created when we start.
self._async_objs: AsyncObjects | None = None
# The task that runs our main loop. We need to save a reference
# to it so that it doesn't get garbage collected while running.
self._loop_coroutine_task: asyncio.Task[None] | None = None
self._main_script_path = config.script_path
self._is_hello = config.is_hello
self._state = RuntimeState.INITIAL
# Initialize managers
self._component_registry = config.component_registry
self._message_cache = ForwardMsgCache()
self._uploaded_file_mgr = config.uploaded_file_manager
self._media_file_mgr = MediaFileManager(storage=config.media_file_storage)
self._cache_storage_manager = config.cache_storage_manager
self._script_cache = ScriptCache()
self._session_mgr = config.session_manager_class(
session_storage=config.session_storage,
uploaded_file_manager=self._uploaded_file_mgr,
script_cache=self._script_cache,
message_enqueued_callback=self._enqueued_some_message,
)
self._stats_mgr = StatsManager()
self._stats_mgr.register_provider(get_data_cache_stats_provider())
self._stats_mgr.register_provider(get_resource_cache_stats_provider())
self._stats_mgr.register_provider(self._message_cache)
self._stats_mgr.register_provider(self._uploaded_file_mgr)
self._stats_mgr.register_provider(SessionStateStatProvider(self._session_mgr))
@property
def state(self) -> RuntimeState:
return self._state
@property
def component_registry(self) -> BaseComponentRegistry:
return self._component_registry
@property
def message_cache(self) -> ForwardMsgCache:
return self._message_cache
@property
def uploaded_file_mgr(self) -> UploadedFileManager:
return self._uploaded_file_mgr
@property
def cache_storage_manager(self) -> CacheStorageManager:
return self._cache_storage_manager
@property
def media_file_mgr(self) -> MediaFileManager:
return self._media_file_mgr
@property
def stats_mgr(self) -> StatsManager:
return self._stats_mgr
@property
def stopped(self) -> Awaitable[None]:
"""A Future that completes when the Runtime's run loop has exited."""
return self._get_async_objs().stopped
# NOTE: A few Runtime methods listed as threadsafe (get_client and
# is_active_session) currently rely on the implementation detail that
# WebsocketSessionManager's get_active_session_info and is_active_session methods
# happen to be threadsafe. This may change with future SessionManager implementations,
# at which point we'll need to formalize our thread safety rules for each
# SessionManager method.
def get_client(self, session_id: str) -> SessionClient | None:
"""Get the SessionClient for the given session_id, or None
if no such session exists.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
session_info = self._session_mgr.get_active_session_info(session_id)
if session_info is None:
return None
return session_info.client
def clear_user_info_for_session(self, session_id: str) -> None:
"""Clear the user_info for the given session_id.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
session_info = self._session_mgr.get_session_info(session_id)
if session_info is not None:
session_info.session.clear_user_info()
async def start(self) -> None:
"""Start the runtime. This must be called only once, before
any other functions are called.
When this coroutine returns, Streamlit is ready to accept new sessions.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
# Create our AsyncObjects. We need to have a running eventloop to
# instantiate our various synchronization primitives.
async_objs = AsyncObjects(
eventloop=asyncio.get_running_loop(),
must_stop=asyncio.Event(),
has_connection=asyncio.Event(),
need_send_data=asyncio.Event(),
started=asyncio.Future(),
stopped=asyncio.Future(),
)
self._async_objs = async_objs
self._loop_coroutine_task = asyncio.create_task(
self._loop_coroutine(), name="Runtime.loop_coroutine"
)
await async_objs.started
def stop(self) -> None:
"""Request that Streamlit close all sessions and stop running.
Note that Streamlit won't stop running immediately.
Notes
-----
Threading: SAFE. May be called from any thread.
"""
async_objs = self._get_async_objs()
def stop_on_eventloop():
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
return
_LOGGER.debug("Runtime stopping...")
self._set_state(RuntimeState.STOPPING)
async_objs.must_stop.set()
async_objs.eventloop.call_soon_threadsafe(stop_on_eventloop)
def is_active_session(self, session_id: str) -> bool:
"""True if the session_id belongs to an active session.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
return self._session_mgr.is_active_session(session_id)
def connect_session(
self,
client: SessionClient,
user_info: dict[str, str | bool | None],
existing_session_id: str | None = None,
session_id_override: str | None = None,
) -> str:
"""Create a new session (or connect to an existing one) and return its unique ID.
Parameters
----------
client
A concrete SessionClient implementation for communicating with
the session's client.
user_info
A dict that contains information about the session's user. For now,
it only (optionally) contains the user's email address.
{
"email": "example@example.com"
}
existing_session_id
The ID of an existing session to reconnect to. If one is not provided, a new
session is created. Note that whether the Runtime's SessionManager supports
reconnecting to an existing session depends on the SessionManager that this
runtime is configured with.
session_id_override
The ID to assign to a new session being created with this method. Setting
this can be useful when the service that a Streamlit Runtime is running in
wants to tie the lifecycle of a Streamlit session to some other session-like
object that it manages. Only one of existing_session_id and
session_id_override should be set.
Returns
-------
str
The session's unique string ID.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
assert not (existing_session_id and session_id_override), (
"Only one of existing_session_id and session_id_override should be set!"
)
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
raise RuntimeStoppedError(f"Can't connect_session (state={self._state})")
session_id = self._session_mgr.connect_session(
client=client,
script_data=ScriptData(self._main_script_path, self._is_hello),
user_info=user_info,
existing_session_id=existing_session_id,
session_id_override=session_id_override,
)
self._set_state(RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED)
self._get_async_objs().has_connection.set()
return session_id
def create_session(
self,
client: SessionClient,
user_info: dict[str, str | bool | None],
existing_session_id: str | None = None,
session_id_override: str | None = None,
) -> str:
"""Create a new session (or connect to an existing one) and return its unique ID.
Notes
-----
This method is simply an alias for connect_session added for backwards
compatibility.
"""
_LOGGER.warning("create_session is deprecated! Use connect_session instead.")
return self.connect_session(
client=client,
user_info=user_info,
existing_session_id=existing_session_id,
session_id_override=session_id_override,
)
def close_session(self, session_id: str) -> None:
"""Close and completely shut down a session.
This differs from disconnect_session in that it always completely shuts down a
session, permanently losing any associated state (session state, uploaded files,
etc.).
This function may be called multiple times for the same session,
which is not an error. (Subsequent calls just no-op.)
Parameters
----------
session_id
The session's unique ID.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
session_info = self._session_mgr.get_session_info(session_id)
if session_info:
self._message_cache.remove_refs_for_session(session_info.session)
self._session_mgr.close_session(session_id)
self._on_session_disconnected()
def disconnect_session(self, session_id: str) -> None:
"""Disconnect a session. It will stop producing ForwardMsgs.
Differs from close_session because disconnected sessions can be reconnected to
for a brief window (depending on the SessionManager/SessionStorage
implementations used by the runtime).
This function may be called multiple times for the same session,
which is not an error. (Subsequent calls just no-op.)
Parameters
----------
session_id
The session's unique ID.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
session_info = self._session_mgr.get_active_session_info(session_id)
if session_info:
# NOTE: Ideally, we'd like to keep ForwardMsgCache refs for a session around
# when a session is disconnected (and defer their cleanup until the session
# is garbage collected), but this would be difficult to do as the
# ForwardMsgCache is not thread safe, and we have no guarantee that the
# garbage collector will only run on the eventloop thread. Because of this,
# we clean up refs now and accept the risk that we're deleting cache entries
# that will be useful once the browser tab reconnects.
self._message_cache.remove_refs_for_session(session_info.session)
self._session_mgr.disconnect_session(session_id)
self._on_session_disconnected()
def handle_backmsg(self, session_id: str, msg: BackMsg) -> None:
"""Send a BackMsg to an active session.
Parameters
----------
session_id
The session's unique ID.
msg
The BackMsg to deliver to the session.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
raise RuntimeStoppedError(f"Can't handle_backmsg (state={self._state})")
session_info = self._session_mgr.get_active_session_info(session_id)
if session_info is None:
_LOGGER.debug(
"Discarding BackMsg for disconnected session (id=%s)", session_id
)
return
session_info.session.handle_backmsg(msg)
def handle_backmsg_deserialization_exception(
self, session_id: str, exc: BaseException
) -> None:
"""Handle an Exception raised during deserialization of a BackMsg.
Parameters
----------
session_id
The session's unique ID.
exc
The Exception.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
raise RuntimeStoppedError(
f"Can't handle_backmsg_deserialization_exception (state={self._state})"
)
session_info = self._session_mgr.get_active_session_info(session_id)
if session_info is None:
_LOGGER.debug(
"Discarding BackMsg Exception for disconnected session (id=%s)",
session_id,
)
return
session_info.session.handle_backmsg_exception(exc)
@property
async def is_ready_for_browser_connection(self) -> tuple[bool, str]:
if self._state not in (
RuntimeState.INITIAL,
RuntimeState.STOPPING,
RuntimeState.STOPPED,
):
return True, "ok"
return False, "unavailable"
async def does_script_run_without_error(self) -> tuple[bool, str]:
"""Load and execute the app's script to verify it runs without an error.
Returns
-------
(True, "ok") if the script completes without error, or (False, err_msg)
if the script raises an exception.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
# NOTE: We create an AppSession directly here instead of using the
# SessionManager intentionally. This isn't a "real" session and is only being
# used to test that the script runs without error.
session = AppSession(
script_data=ScriptData(self._main_script_path, self._is_hello),
uploaded_file_manager=self._uploaded_file_mgr,
script_cache=self._script_cache,
message_enqueued_callback=self._enqueued_some_message,
user_info={"email": "test@example.com"},
)
try:
session.request_rerun(None)
now = time.perf_counter()
while ( # noqa: ASYNC110
SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state
and (time.perf_counter() - now) < SCRIPT_RUN_CHECK_TIMEOUT
):
await asyncio.sleep(0.1)
if SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state:
return False, "timeout"
ok = session.session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY]
msg = "ok" if ok else "error"
return ok, msg
finally:
session.shutdown()
def _set_state(self, new_state: RuntimeState) -> None:
_LOGGER.debug("Runtime state: %s -> %s", self._state, new_state)
self._state = new_state
async def _loop_coroutine(self) -> None:
"""The main Runtime loop.
This function won't exit until `stop` is called.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
async_objs = self._get_async_objs()
try:
if self._state == RuntimeState.INITIAL:
self._set_state(RuntimeState.NO_SESSIONS_CONNECTED)
elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED:
pass
else:
raise RuntimeError(f"Bad Runtime state at start: {self._state}")
# Signal that we're started and ready to accept sessions
async_objs.started.set_result(None)
while not async_objs.must_stop.is_set():
if self._state == RuntimeState.NO_SESSIONS_CONNECTED: # type: ignore[comparison-overlap]
# mypy 1.4 incorrectly thinks this if-clause is unreachable,
# because it thinks self._state must be INITIAL | ONE_OR_MORE_SESSIONS_CONNECTED.
# Wait for new websocket connections (new sessions):
_, pending_tasks = await asyncio.wait( # type: ignore[unreachable]
(
asyncio.create_task(async_objs.must_stop.wait()),
asyncio.create_task(async_objs.has_connection.wait()),
),
return_when=asyncio.FIRST_COMPLETED,
)
# Clean up pending tasks to avoid memory leaks
for task in pending_tasks:
task.cancel()
elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED:
async_objs.need_send_data.clear()
for active_session_info in self._session_mgr.list_active_sessions():
msg_list = active_session_info.session.flush_browser_queue()
for msg in msg_list:
try:
self._send_message(active_session_info, msg)
except SessionClientDisconnectedError:
self._session_mgr.disconnect_session(
active_session_info.session.id
)
# Yield for a tick after sending a message.
await asyncio.sleep(0)
# Yield for a few milliseconds between session message
# flushing.
await asyncio.sleep(0.01)
else:
# Break out of the thread loop if we encounter any other state.
break
# Wait for new proto messages that need to be sent out:
_, pending_tasks = await asyncio.wait(
(
asyncio.create_task(async_objs.must_stop.wait()),
asyncio.create_task(async_objs.need_send_data.wait()),
),
return_when=asyncio.FIRST_COMPLETED,
)
# We need to cancel the pending tasks (the `must_stop` one in most situations).
# Otherwise, this would stack up one waiting task per loop
# (e.g. per forward message). These tasks cannot be garbage collected
# causing an increase in memory (-> memory leak).
for task in pending_tasks:
task.cancel()
# Shut down all AppSessions.
for session_info in self._session_mgr.list_sessions():
# NOTE: We want to fully shut down sessions when the runtime stops for
# now, but this may change in the future if/when our notion of a session
# is no longer so tightly coupled to a browser tab.
self._session_mgr.close_session(session_info.session.id)
self._set_state(RuntimeState.STOPPED)
async_objs.stopped.set_result(None)
except Exception as e:
async_objs.stopped.set_exception(e)
traceback.print_exc()
_LOGGER.info(
"""
Please report this bug at https://github.com/streamlit/streamlit/issues.
"""
)
def _send_message(self, session_info: ActiveSessionInfo, msg: ForwardMsg) -> None:
"""Send a message to a client.
If the client is likely to have already cached the message, we may
instead send a "reference" message that contains only the hash of the
message.
Parameters
----------
session_info : ActiveSessionInfo
The ActiveSessionInfo associated with websocket
msg : ForwardMsg
The message to send to the client
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
msg.metadata.cacheable = is_cacheable_msg(msg)
msg_to_send = msg
if msg.metadata.cacheable:
populate_hash_if_needed(msg)
if self._message_cache.has_message_reference(
msg, session_info.session, session_info.script_run_count
):
# This session has probably cached this message. Send
# a reference instead.
_LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
msg_to_send = create_reference_msg(msg)
# Cache the message so it can be referenced in the future.
# If the message is already cached, this will reset its
# age.
_LOGGER.debug("Caching message (hash=%s)", msg.hash)
self._message_cache.add_message(
msg, session_info.session, session_info.script_run_count
)
# If this was a `script_finished` message, we increment the
# script_run_count for this session, and update the cache
if msg.WhichOneof("type") == "script_finished" and (
msg.script_finished == ForwardMsg.FINISHED_SUCCESSFULLY
or (
config.get_option(
"global.includeFragmentRunsInForwardMessageCacheCount"
)
and msg.script_finished == ForwardMsg.FINISHED_FRAGMENT_RUN_SUCCESSFULLY
)
):
_LOGGER.debug(
"Script run finished successfully; "
"removing expired entries from MessageCache "
"(max_age=%s)",
config.get_option("global.maxCachedMessageAge"),
)
session_info.script_run_count += 1
self._message_cache.remove_expired_entries_for_session(
session_info.session, session_info.script_run_count
)
# Ship it off!
session_info.client.write_forward_msg(msg_to_send)
def _enqueued_some_message(self) -> None:
"""Callback called by AppSession after the AppSession has enqueued a
message. Sets the "needs_send_data" event, which causes our core
loop to wake up and flush client message queues.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
async_objs = self._get_async_objs()
async_objs.eventloop.call_soon_threadsafe(async_objs.need_send_data.set)
def _get_async_objs(self) -> AsyncObjects:
"""Return our AsyncObjects instance. If the Runtime hasn't been
started, this will raise an error.
"""
if self._async_objs is None:
raise RuntimeError("Runtime hasn't started yet!")
return self._async_objs
def _on_session_disconnected(self) -> None:
"""Set the runtime state to NO_SESSIONS_CONNECTED if the last active
session was disconnected.
"""
if (
self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED
and self._session_mgr.num_active_sessions() == 0
):
self._get_async_objs().has_connection.clear()
self._set_state(RuntimeState.NO_SESSIONS_CONNECTED)

View File

@@ -0,0 +1,106 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runtime-related utility functions."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from streamlit import config
from streamlit.errors import MarkdownFormattedException, StreamlitAPIException
from streamlit.runtime.forward_msg_cache import populate_hash_if_needed
if TYPE_CHECKING:
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
class MessageSizeError(MarkdownFormattedException):
"""Exception raised when a websocket message is larger than the configured limit."""
def __init__(self, failed_msg_str: Any):
msg = self._get_message(failed_msg_str)
super().__init__(msg)
def _get_message(self, failed_msg_str: Any) -> str:
# This needs to have zero indentation otherwise the markdown will render incorrectly.
return (
f"""
**Data of size {len(failed_msg_str) / 1e6:.1f} MB exceeds the message size limit of {get_max_message_size_bytes() / 1e6} MB.**
This is often caused by a large chart or dataframe. Please decrease the amount of data sent
to the browser, or increase the limit by setting the config option `server.maxMessageSize`.
[Click here to learn more about config options](https://docs.streamlit.io/develop/api-reference/configuration/config.toml).
_Note that increasing the limit may lead to long loading times and large memory consumption
of the client's browser and the Streamlit server._
"""
).strip("\n")
class BadDurationStringError(StreamlitAPIException):
"""Raised when a bad duration argument string is passed."""
def __init__(self, duration: str):
MarkdownFormattedException.__init__(
self,
"TTL string doesn't look right. It should be formatted as"
f"`'1d2h34m'` or `2 days`, for example. Got: {duration}",
)
def is_cacheable_msg(msg: ForwardMsg) -> bool:
"""True if the given message qualifies for caching."""
if msg.WhichOneof("type") in {"ref_hash", "initialize"}:
# Some message types never get cached
return False
return msg.ByteSize() >= int(config.get_option("global.minCachedMessageSize"))
def serialize_forward_msg(msg: ForwardMsg) -> bytes:
"""Serialize a ForwardMsg to send to a client.
If the message is too large, it will be converted to an exception message
instead.
"""
populate_hash_if_needed(msg)
msg_str = msg.SerializeToString()
if len(msg_str) > get_max_message_size_bytes():
import streamlit.elements.exception as exception
# Overwrite the offending ForwardMsg.delta with an error to display.
# This assumes that the size limit wasn't exceeded due to metadata.
exception.marshall(msg.delta.new_element.exception, MessageSizeError(msg_str))
msg_str = msg.SerializeToString()
return msg_str
# This needs to be initialized lazily to avoid calling config.get_option() and
# thus initializing config options when this file is first imported.
_max_message_size_bytes: int | None = None
def get_max_message_size_bytes() -> int:
"""Returns the max websocket message size in bytes.
This will lazyload the value from the config and store it in the global symbol table.
"""
global _max_message_size_bytes
if _max_message_size_bytes is None:
_max_message_size_bytes = config.get_option("server.maxMessageSize") * int(1e6)
return _max_message_size_bytes

View File

@@ -0,0 +1,46 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from dataclasses import dataclass, field
@dataclass(frozen=True)
class ScriptData:
"""Contains parameters related to running a script."""
main_script_path: str
is_hello: bool = False
script_folder: str = field(init=False)
name: str = field(init=False)
def __post_init__(self) -> None:
"""Set some computed values derived from main_script_path.
The usage of object.__setattr__ is necessary because trying to set
self.script_folder or self.name normally, even within the __init__ method, will
explode since we declared this dataclass to be frozen.
We do this in __post_init__ so that we can use the auto-generated __init__
method that most dataclasses use.
"""
main_script_path = os.path.abspath(self.main_script_path)
script_folder = os.path.dirname(main_script_path)
object.__setattr__(self, "script_folder", script_folder)
basename = os.path.basename(main_script_path)
name = str(os.path.splitext(basename)[0])
object.__setattr__(self, "name", name)

View File

@@ -0,0 +1,38 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.runtime.scriptrunner.script_runner import ScriptRunner, ScriptRunnerEvent
from streamlit.runtime.scriptrunner_utils.exceptions import (
RerunException,
StopException,
)
from streamlit.runtime.scriptrunner_utils.script_requests import RerunData
from streamlit.runtime.scriptrunner_utils.script_run_context import (
ScriptRunContext,
add_script_run_ctx,
enqueue_message,
get_script_run_ctx,
)
__all__ = [
"RerunData",
"ScriptRunContext",
"add_script_run_ctx",
"get_script_run_ctx",
"enqueue_message",
"RerunException",
"ScriptRunner",
"ScriptRunnerEvent",
"StopException",
]

View File

@@ -0,0 +1,159 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any, Callable
from streamlit import util
from streamlit.delta_generator_singletons import (
context_dg_stack,
get_default_dg_stack_value,
)
from streamlit.error_util import handle_uncaught_app_exception
from streamlit.errors import FragmentHandledException
from streamlit.runtime.scriptrunner_utils.exceptions import (
RerunException,
StopException,
)
if TYPE_CHECKING:
from streamlit.runtime.scriptrunner_utils.script_requests import RerunData
from streamlit.runtime.scriptrunner_utils.script_run_context import ScriptRunContext
class modified_sys_path:
"""A context for prepending a directory to sys.path for a second.
Code inspired by IPython:
Source: https://github.com/ipython/ipython/blob/master/IPython/utils/syspathcontext.py#L42
"""
def __init__(self, main_script_path: str):
self._main_script_path = main_script_path
self._added_path = False
def __repr__(self) -> str:
return util.repr_(self)
def __enter__(self):
if self._main_script_path not in sys.path:
sys.path.insert(0, self._main_script_path)
self._added_path = True
def __exit__(self, type, value, traceback):
if self._added_path:
try:
sys.path.remove(self._main_script_path)
except ValueError:
# It's already removed.
pass
# Returning False causes any exceptions to be re-raised.
return False
def exec_func_with_error_handling(
func: Callable[[], Any], ctx: ScriptRunContext
) -> tuple[
Any | None,
bool,
RerunData | None,
bool,
Exception | None,
]:
"""Execute the passed function wrapped in a try/except block.
This function is called by the script runner to execute the user's script or
fragment reruns, but also for the execution of fragment code in context of a normal
app run. This wrapper ensures that handle_uncaught_exception messages show up in the
correct context.
Parameters
----------
func : callable
The function to execute wrapped in the try/except block.
ctx : ScriptRunContext
The context in which the script is being run.
Returns
-------
tuple
A tuple containing:
- The result of the passed function.
- A boolean indicating whether the script ran without errors (RerunException and
StopException don't count as errors).
- The RerunData instance belonging to a RerunException if the script was
interrupted by a RerunException.
- A boolean indicating whether the script was stopped prematurely (False for
RerunExceptions, True for all other exceptions).
- The uncaught exception if one occurred, None otherwise
"""
run_without_errors = True
# This will be set to a RerunData instance if our execution
# is interrupted by a RerunException.
rerun_exception_data: RerunData | None = None
# If the script stops early, we don't want to remove unseen widgets,
# so we track this to potentially skip session state cleanup later.
premature_stop: bool = False
# The result of the passed function
result: Any | None = None
# The uncaught exception if one occurred, None otherwise
uncaught_exception: Exception | None = None
try:
result = func()
except RerunException as e:
rerun_exception_data = e.rerun_data
# Since the script is about to rerun, we may need to reset our cursors/dg_stack
# so that we write to the right place in the app. For full script runs, this
# needs to happen in case the same thread reruns our script (a different thread
# would automatically come with fresh cursors/dg_stack values). For fragments,
# it doesn't matter either way since the fragment resets these values from its
# snapshot before execution.
ctx.cursors.clear()
context_dg_stack.set(get_default_dg_stack_value())
# Interruption due to a rerun is usually from `st.rerun()`, which
# we want to count as a script completion so triggers reset.
# It is also possible for this to happen if fast reruns is off,
# but this is very rare.
premature_stop = False
except StopException:
# This is thrown when the script executes `st.stop()`.
# We don't have to do anything here.
premature_stop = True
except FragmentHandledException:
run_without_errors = False
premature_stop = True
except Exception as ex:
run_without_errors = False
premature_stop = True
handle_uncaught_app_exception(ex)
uncaught_exception = ex
return (
result,
run_without_errors,
rerun_exception_data,
premature_stop,
uncaught_exception,
)

View File

@@ -0,0 +1,273 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import ast
import sys
from typing import Any, Final
from streamlit import config
# When a Streamlit app is magicified, we insert a `magic_funcs` import near the top of
# its module's AST:
# import streamlit.runtime.scriptrunner.magic_funcs as __streamlitmagic__
MAGIC_MODULE_NAME: Final = "__streamlitmagic__"
def add_magic(code: str, script_path: str) -> Any:
"""Modifies the code to support magic Streamlit commands.
Parameters
----------
code : str
The Python code.
script_path : str
The path to the script file.
Returns
-------
ast.Module
The syntax tree for the code.
"""
# Pass script_path so we get pretty exceptions.
tree = ast.parse(code, script_path, "exec")
file_ends_in_semicolon = _does_file_end_in_semicolon(tree, code)
_modify_ast_subtree(
tree, is_root=True, file_ends_in_semicolon=file_ends_in_semicolon
)
return tree
def _modify_ast_subtree(
tree: Any,
body_attr: str = "body",
is_root: bool = False,
file_ends_in_semicolon: bool = False,
):
"""Parses magic commands and modifies the given AST (sub)tree."""
body = getattr(tree, body_attr)
for i, node in enumerate(body):
node_type = type(node)
# Recursively parses the content of the statements
# `with` as well as function definitions.
# Also covers their async counterparts
if (
node_type is ast.FunctionDef
or node_type is ast.With
or node_type is ast.AsyncFunctionDef
or node_type is ast.AsyncWith
):
_modify_ast_subtree(node)
# Recursively parses the content of the statements
# `for` and `while`.
# Also covers their async counterparts
elif (
node_type is ast.For or node_type is ast.While or node_type is ast.AsyncFor
):
_modify_ast_subtree(node)
_modify_ast_subtree(node, "orelse")
# Recursively parses methods in a class.
elif node_type is ast.ClassDef:
for inner_node in node.body:
if type(inner_node) in {ast.FunctionDef, ast.AsyncFunctionDef}:
_modify_ast_subtree(inner_node)
# Recursively parses the contents of try statements,
# all their handlers (except and else) and the finally body
elif node_type is ast.Try or (
sys.version_info >= (3, 11) and node_type is ast.TryStar
):
_modify_ast_subtree(node)
_modify_ast_subtree(node, body_attr="finalbody")
_modify_ast_subtree(node, body_attr="orelse")
for handler_node in node.handlers:
_modify_ast_subtree(handler_node)
# Recursively parses if blocks, as well as their else/elif blocks
# (else/elif are both mapped to orelse)
# it intentionally does not parse the test expression.
elif node_type is ast.If:
_modify_ast_subtree(node)
_modify_ast_subtree(node, "orelse")
elif sys.version_info >= (3, 10) and node_type is ast.Match:
for case_node in node.cases:
_modify_ast_subtree(case_node)
# Convert standalone expression nodes to st.write
elif node_type is ast.Expr:
value = _get_st_write_from_expr(
node,
i,
parent_type=type(tree),
is_root=is_root,
is_last_expr=(i == len(body) - 1),
file_ends_in_semicolon=file_ends_in_semicolon,
)
if value is not None:
node.value = value
if is_root:
# Import Streamlit so we can use it in the new_value above.
_insert_import_statement(tree)
ast.fix_missing_locations(tree)
def _insert_import_statement(tree: Any) -> None:
"""Insert Streamlit import statement at the top(ish) of the tree."""
st_import = _build_st_import_statement()
# If the 0th node is already an import statement, put the Streamlit
# import below that, so we don't break "from __future__ import".
if tree.body and type(tree.body[0]) in {ast.ImportFrom, ast.Import}:
tree.body.insert(1, st_import)
# If the 0th node is a docstring and the 1st is an import statement,
# put the Streamlit import below those, so we don't break "from
# __future__ import".
elif (
len(tree.body) > 1
and (
type(tree.body[0]) is ast.Expr
and _is_string_constant_node(tree.body[0].value)
)
and type(tree.body[1]) in {ast.ImportFrom, ast.Import}
):
tree.body.insert(2, st_import)
else:
tree.body.insert(0, st_import)
def _build_st_import_statement():
"""Build AST node for `import magic_funcs as __streamlitmagic__`."""
return ast.Import(
names=[
ast.alias(
name="streamlit.runtime.scriptrunner.magic_funcs",
asname=MAGIC_MODULE_NAME,
)
]
)
def _build_st_write_call(nodes):
"""Build AST node for `__streamlitmagic__.transparent_write(*nodes)`."""
return ast.Call(
func=ast.Attribute(
attr="transparent_write",
value=ast.Name(id=MAGIC_MODULE_NAME, ctx=ast.Load()),
ctx=ast.Load(),
),
args=nodes,
keywords=[],
)
def _get_st_write_from_expr(
node, i, parent_type, is_root, is_last_expr, file_ends_in_semicolon
):
# Don't wrap function calls
# (Unless the function call happened at the end of the root node, AND
# magic.displayLastExprIfNoSemicolon is True. This allows us to support notebook-like
# behavior, where we display the last function in a cell)
if type(node.value) is ast.Call and not _is_displayable_last_expr(
is_root, is_last_expr, file_ends_in_semicolon
):
return None
# Don't wrap DocString nodes
# (Unless magic.displayRootDocString, in which case we do wrap the root-level
# docstring with st.write. This allows us to support notebook-like behavior
# where you can have a cell with a markdown string)
if _is_docstring_node(
node.value, i, parent_type
) and not _should_display_docstring_like_node_anyway(is_root):
return None
# Don't wrap yield nodes
if type(node.value) is ast.Yield or type(node.value) is ast.YieldFrom:
return None
# Don't wrap await nodes
if type(node.value) is ast.Await:
return None
# If tuple, call st.write(*the_tuple). This allows us to add a comma at the end of a
# statement to turn it into an expression that should be st-written. Ex:
# "np.random.randn(1000, 2),"
args = node.value.elts if type(node.value) is ast.Tuple else [node.value]
return _build_st_write_call(args)
def _is_string_constant_node(node) -> bool:
return isinstance(node, ast.Constant) and isinstance(node.value, str)
def _is_docstring_node(node, node_index, parent_type) -> bool:
return (
node_index == 0
and _is_string_constant_node(node)
and parent_type in {ast.FunctionDef, ast.AsyncFunctionDef, ast.Module}
)
def _does_file_end_in_semicolon(tree, code: str) -> bool:
file_ends_in_semicolon = False
# Avoid spending time with this operation if magic.displayLastExprIfNoSemicolon is
# not set.
if config.get_option("magic.displayLastExprIfNoSemicolon"):
if len(tree.body) == 0:
return False
last_line_num = getattr(tree.body[-1], "end_lineno", None)
if last_line_num is not None:
last_line_str: str = code.split("\n")[last_line_num - 1]
file_ends_in_semicolon = last_line_str.strip(" ").endswith(";")
return file_ends_in_semicolon
def _is_displayable_last_expr(
is_root: bool, is_last_expr: bool, file_ends_in_semicolon: bool
) -> bool:
return (
# This is a "displayable last expression" if...
# ...it's actually the last expression...
is_last_expr
# ...in the root scope...
and is_root
# ...it does not end in a semicolon...
and not file_ends_in_semicolon
# ...and this config option is telling us to show it
and config.get_option("magic.displayLastExprIfNoSemicolon")
)
def _should_display_docstring_like_node_anyway(is_root: bool) -> bool:
return config.get_option("magic.displayRootDocString") and is_root

View File

@@ -0,0 +1,32 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Any
from streamlit.runtime.metrics_util import gather_metrics
@gather_metrics("magic")
def transparent_write(*args: Any) -> Any:
"""The function that gets magic-ified into Streamlit apps.
This is just st.write, but returns the arguments you passed to it.
"""
import streamlit as st
st.write(*args)
if len(args) == 1:
return args[0]
return args

View File

@@ -0,0 +1,89 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os.path
import threading
from typing import Any
from streamlit import config
from streamlit.runtime.scriptrunner import magic
from streamlit.source_util import open_python_file
class ScriptCache:
"""Thread-safe cache of Python script bytecode."""
def __init__(self):
# Mapping of script_path: bytecode
self._cache: dict[str, Any] = {}
self._lock = threading.Lock()
def clear(self) -> None:
"""Remove all entries from the cache.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
with self._lock:
self._cache.clear()
def get_bytecode(self, script_path: str) -> Any:
"""Return the bytecode for the Python script at the given path.
If the bytecode is not already in the cache, the script will be
compiled first.
Raises
------
Any Exception raised while reading or compiling the script.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
script_path = os.path.abspath(script_path)
with self._lock:
bytecode = self._cache.get(script_path, None)
if bytecode is not None:
# Fast path: the code is already cached.
return bytecode
# Populate the cache
with open_python_file(script_path) as f:
filebody = f.read()
if config.get_option("runner.magicEnabled"):
filebody = magic.add_magic(filebody, script_path)
bytecode = compile( # type: ignore
filebody,
# Pass in the file path so it can show up in exceptions.
script_path,
# We're compiling entire blocks of Python, so we need "exec"
# mode (as opposed to "eval" or "single").
mode="exec",
# Don't inherit any flags or "future" statements.
flags=0,
dont_inherit=1,
# Use the default optimization options.
optimize=-1,
)
self._cache[script_path] = bytecode
return bytecode

View File

@@ -0,0 +1,756 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import gc
import sys
import threading
import types
from contextlib import contextmanager
from enum import Enum
from timeit import default_timer as timer
from typing import TYPE_CHECKING, Callable, Final, Literal, cast
from blinker import Signal
from streamlit import config, runtime, util
from streamlit.errors import FragmentStorageKeyError
from streamlit.logger import get_logger
from streamlit.proto.ClientState_pb2 import ClientState
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.metrics_util import (
create_page_profile_message,
to_microseconds,
)
from streamlit.runtime.pages_manager import PagesManager
from streamlit.runtime.scriptrunner.exec_code import (
exec_func_with_error_handling,
modified_sys_path,
)
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
from streamlit.runtime.scriptrunner_utils.exceptions import (
RerunException,
StopException,
)
from streamlit.runtime.scriptrunner_utils.script_requests import (
RerunData,
ScriptRequests,
ScriptRequestType,
)
from streamlit.runtime.scriptrunner_utils.script_run_context import (
ScriptRunContext,
add_script_run_ctx,
get_script_run_ctx,
)
from streamlit.runtime.state import (
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
SafeSessionState,
SessionState,
)
from streamlit.source_util import page_sort_key
if TYPE_CHECKING:
from streamlit.runtime.fragment import FragmentStorage
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
_LOGGER: Final = get_logger(__name__)
class ScriptRunnerEvent(Enum):
# "Control" events. These are emitted when the ScriptRunner's state changes.
# The script started running.
SCRIPT_STARTED = "SCRIPT_STARTED"
# The script run stopped because of a compile error.
SCRIPT_STOPPED_WITH_COMPILE_ERROR = "SCRIPT_STOPPED_WITH_COMPILE_ERROR"
# The script run stopped because it ran to completion, or was
# interrupted by the user.
SCRIPT_STOPPED_WITH_SUCCESS = "SCRIPT_STOPPED_WITH_SUCCESS"
# The script run stopped in order to start a script run with newer widget state.
SCRIPT_STOPPED_FOR_RERUN = "SCRIPT_STOPPED_FOR_RERUN"
# The script run corresponding to a fragment ran to completion, or was interrupted
# by the user.
FRAGMENT_STOPPED_WITH_SUCCESS = "FRAGMENT_STOPPED_WITH_SUCCESS"
# The ScriptRunner is done processing the ScriptEventQueue and
# is shut down.
SHUTDOWN = "SHUTDOWN"
# "Data" events. These are emitted when the ScriptRunner's script has
# data to send to the frontend.
# The script has a ForwardMsg to send to the frontend.
ENQUEUE_FORWARD_MSG = "ENQUEUE_FORWARD_MSG"
"""
Note [Threading]
There are two kinds of threads in Streamlit, the main thread and script threads.
The main thread is started by invoking the Streamlit CLI, and bootstraps the
framework and runs the Tornado webserver.
A script thread is created by a ScriptRunner when it starts. The script thread
is where the ScriptRunner executes, including running the user script itself,
processing messages to/from the frontend, and all the Streamlit library function
calls in the user script.
It is possible for the user script to spawn its own threads, which could call
Streamlit functions. We restrict the ScriptRunner's execution control to the
script thread. Calling Streamlit functions from other threads is unlikely to
work correctly due to lack of ScriptRunContext, so we may add a guard against
it in the future.
"""
# For projects that have a pages folder, we assume that this is a script that
# is designed to leverage our original v1 version of multi-page apps. This
# function will be called to run the script in lieu of the main script. This
# function simulates the v1 setup using the modern v2 commands (st.navigation)
def _mpa_v1(main_script_path: str):
from pathlib import Path
from streamlit.commands.navigation import PageType, _navigation
from streamlit.navigation.page import StreamlitPage
# Select the folder that should be used for the pages:
MAIN_SCRIPT_PATH = Path(main_script_path).resolve()
PAGES_FOLDER = MAIN_SCRIPT_PATH.parent / "pages"
# Read out the my_pages folder and create a page for every script:
pages = PAGES_FOLDER.glob("*.py")
pages = sorted(
[page for page in pages if page.name.endswith(".py")], key=page_sort_key
)
# Use this script as the main page and
main_page = StreamlitPage(MAIN_SCRIPT_PATH, default=True)
all_pages = [main_page] + [
StreamlitPage(PAGES_FOLDER / page.name) for page in pages
]
# Initialize the navigation with all the pages:
position: Literal["sidebar", "hidden"] = (
"hidden"
if config.get_option("client.showSidebarNavigation") is False
else "sidebar"
)
page = _navigation(
cast("list[PageType]", all_pages),
position=position,
expanded=False,
)
if page._page != main_page._page:
# Only run the page if it is not pointing to this script:
page.run()
# Finish the script execution here to only run the selected page
raise StopException()
class ScriptRunner:
def __init__(
self,
session_id: str,
main_script_path: str,
session_state: SessionState,
uploaded_file_mgr: UploadedFileManager,
script_cache: ScriptCache,
initial_rerun_data: RerunData,
user_info: dict[str, str | bool | None],
fragment_storage: FragmentStorage,
pages_manager: PagesManager,
):
"""Initialize the ScriptRunner.
(The ScriptRunner won't start executing until start() is called.)
Parameters
----------
session_id
The AppSession's id.
main_script_path
Path to our main app script.
session_state
The AppSession's SessionState instance.
uploaded_file_mgr
The File manager to store the data uploaded by the file_uploader widget.
script_cache
A ScriptCache instance.
initial_rerun_data
RerunData to initialize this ScriptRunner with.
user_info
A dict that contains information about the current user. For now,
it only contains the user's email address.
{
"email": "example@example.com"
}
Information about the current user is optionally provided when a
websocket connection is initialized via the "X-Streamlit-User" header.
fragment_storage
The AppSession's FragmentStorage instance.
"""
self._session_id = session_id
self._main_script_path = main_script_path
self._session_state = SafeSessionState(
session_state, yield_callback=self._maybe_handle_execution_control_request
)
self._uploaded_file_mgr = uploaded_file_mgr
self._script_cache = script_cache
self._user_info = user_info
self._fragment_storage = fragment_storage
self._pages_manager = pages_manager
self._requests = ScriptRequests()
self._requests.request_rerun(initial_rerun_data)
self.on_event = Signal(
doc="""Emitted when a ScriptRunnerEvent occurs.
This signal is generally emitted on the ScriptRunner's script
thread (which is *not* the same thread that the ScriptRunner was
created on).
Parameters
----------
sender: ScriptRunner
The sender of the event (this ScriptRunner).
event : ScriptRunnerEvent
forward_msg : ForwardMsg | None
The ForwardMsg to send to the frontend. Set only for the
ENQUEUE_FORWARD_MSG event.
exception : BaseException | None
Our compile error. Set only for the
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
widget_states : streamlit.proto.WidgetStates_pb2.WidgetStates | None
The ScriptRunner's final WidgetStates. Set only for the
SHUTDOWN event.
"""
)
# Set to true while we're executing. Used by
# _maybe_handle_execution_control_request.
self._execing = False
# This is initialized in start()
self._script_thread: threading.Thread | None = None
def __repr__(self) -> str:
return util.repr_(self)
def request_stop(self) -> None:
"""Request that the ScriptRunner stop running its script and
shut down. The ScriptRunner will handle this request when it reaches
an interrupt point.
Safe to call from any thread.
"""
self._requests.request_stop()
def request_rerun(self, rerun_data: RerunData) -> bool:
"""Request that the ScriptRunner interrupt its currently-running
script and restart it.
If the ScriptRunner has been stopped, this request can't be honored:
return False.
Otherwise, record the request and return True. The ScriptRunner will
handle the rerun request as soon as it reaches an interrupt point.
Safe to call from any thread.
"""
return self._requests.request_rerun(rerun_data)
def start(self) -> None:
"""Start a new thread to process the ScriptEventQueue.
This must be called only once.
"""
if self._script_thread is not None:
raise Exception("ScriptRunner was already started")
self._script_thread = threading.Thread(
target=self._run_script_thread,
name="ScriptRunner.scriptThread",
)
self._script_thread.start()
def _get_script_run_ctx(self) -> ScriptRunContext:
"""Get the ScriptRunContext for the current thread.
Returns
-------
ScriptRunContext
The ScriptRunContext for the current thread.
Raises
------
AssertionError
If called outside of a ScriptRunner thread.
RuntimeError
If there is no ScriptRunContext for the current thread.
"""
assert self._is_in_script_thread()
ctx = get_script_run_ctx()
if ctx is None:
# This should never be possible on the script_runner thread.
raise RuntimeError(
"ScriptRunner thread has a null ScriptRunContext. "
"Something has gone very wrong!"
)
return ctx
def _run_script_thread(self) -> None:
"""The entry point for the script thread.
Processes the ScriptRequestQueue, which will at least contain the RERUN
request that will trigger the first script-run.
When the ScriptRequestQueue is empty, or when a SHUTDOWN request is
dequeued, this function will exit and its thread will terminate.
"""
assert self._is_in_script_thread()
_LOGGER.debug("Beginning script thread")
# Create and attach the thread's ScriptRunContext
ctx = ScriptRunContext(
session_id=self._session_id,
_enqueue=self._enqueue_forward_msg,
script_requests=self._requests,
query_string="",
session_state=self._session_state,
uploaded_file_mgr=self._uploaded_file_mgr,
main_script_path=self._main_script_path,
user_info=self._user_info,
gather_usage_stats=bool(config.get_option("browser.gatherUsageStats")),
fragment_storage=self._fragment_storage,
pages_manager=self._pages_manager,
context_info=None,
)
add_script_run_ctx(threading.current_thread(), ctx)
request = self._requests.on_scriptrunner_ready()
while request.type == ScriptRequestType.RERUN:
# When the script thread starts, we'll have a pending rerun
# request that we'll handle immediately. When the script finishes,
# it's possible that another request has come in that we need to
# handle, which is why we call _run_script in a loop.
self._run_script(request.rerun_data)
request = self._requests.on_scriptrunner_ready()
assert request.type == ScriptRequestType.STOP
# Send a SHUTDOWN event before exiting, so some state can be saved
# for use in a future script run when not triggered by the client.
client_state = ClientState()
client_state.query_string = ctx.query_string
client_state.page_script_hash = ctx.page_script_hash
self.on_event.send(
self, event=ScriptRunnerEvent.SHUTDOWN, client_state=client_state
)
def _is_in_script_thread(self) -> bool:
"""True if the calling function is running in the script thread."""
return self._script_thread == threading.current_thread()
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
"""Enqueue a ForwardMsg to our browser queue.
This private function is called by ScriptRunContext only.
It may be called from the script thread OR the main thread.
"""
# Whenever we enqueue a ForwardMsg, we also handle any pending
# execution control request. This means that a script can be
# cleanly interrupted and stopped inside most `st.foo` calls.
self._maybe_handle_execution_control_request()
# Pass the message to our associated AppSession.
self.on_event.send(
self, event=ScriptRunnerEvent.ENQUEUE_FORWARD_MSG, forward_msg=msg
)
def _maybe_handle_execution_control_request(self) -> None:
"""Check our current ScriptRequestState to see if we have a
pending STOP or RERUN request.
This function is called every time the app script enqueues a
ForwardMsg, which means that most `st.foo` commands - which generally
involve sending a ForwardMsg to the frontend - act as implicit
yield points in the script's execution.
"""
if not self._is_in_script_thread():
# We can only handle execution_control_request if we're on the
# script execution thread. However, it's possible for deltas to
# be enqueued (and, therefore, for this function to be called)
# in separate threads, so we check for that here.
return
if not self._execing:
# If the _execing flag is not set, we're not actually inside
# an exec() call. This happens when our script exec() completes,
# we change our state to STOPPED, and a statechange-listener
# enqueues a new ForwardEvent
return
request = self._requests.on_scriptrunner_yield()
if request is None:
# No RERUN or STOP request.
return
if request.type == ScriptRequestType.RERUN:
raise RerunException(request.rerun_data)
assert request.type == ScriptRequestType.STOP
raise StopException()
@contextmanager
def _set_execing_flag(self):
"""A context for setting the ScriptRunner._execing flag.
Used by _maybe_handle_execution_control_request to ensure that
we only handle requests while we're inside an exec() call
"""
if self._execing:
raise RuntimeError("Nested set_execing_flag call")
self._execing = True
try:
yield
finally:
self._execing = False
def _run_script(self, rerun_data: RerunData) -> None:
"""Run our script.
Parameters
----------
rerun_data: RerunData
The RerunData to use.
"""
assert self._is_in_script_thread()
# An explicit loop instead of recursion to avoid stack overflows
while True:
_LOGGER.debug("Running script %s", rerun_data)
start_time: float = timer()
prep_time: float = 0 # This will be overwritten once preparations are done.
if not rerun_data.fragment_id_queue:
# Don't clear session refs for media files if we're running a fragment.
# Otherwise, we're likely to remove files that still have corresponding
# download buttons/links to them present in the app, which will result
# in a 404 should the user click on them.
runtime.get_instance().media_file_mgr.clear_session_refs()
self._pages_manager.set_script_intent(
rerun_data.page_script_hash, rerun_data.page_name
)
active_script = self._pages_manager.get_initial_active_script(
rerun_data.page_script_hash, rerun_data.page_name
)
main_page_info = self._pages_manager.get_main_page()
page_script_hash = (
active_script["page_script_hash"]
if active_script is not None
else main_page_info["page_script_hash"]
)
ctx = self._get_script_run_ctx()
# Clear widget state on page change. This normally happens implicitly
# in the script run cleanup steps, but doing it explicitly ensures
# it happens even if a script run was interrupted.
previous_page_script_hash = ctx.page_script_hash
if previous_page_script_hash != page_script_hash:
# Page changed, enforce reset widget state where possible.
# This enforcement matters when a new script thread is started
# before the previous script run is completed (from user
# interaction). Use the widget ids from the rerun data to
# maintain some widget state, as the rerun data should
# contain the latest widget ids from the frontend.
widget_ids: set[str] = set()
if (
rerun_data.widget_states is not None
and rerun_data.widget_states.widgets is not None
):
widget_ids = {w.id for w in rerun_data.widget_states.widgets}
self._session_state.on_script_finished(widget_ids)
fragment_ids_this_run = list(rerun_data.fragment_id_queue)
ctx.reset(
query_string=rerun_data.query_string,
page_script_hash=page_script_hash,
fragment_ids_this_run=fragment_ids_this_run,
context_info=rerun_data.context_info,
)
self.on_event.send(
self,
event=ScriptRunnerEvent.SCRIPT_STARTED,
page_script_hash=page_script_hash,
fragment_ids_this_run=fragment_ids_this_run,
pages=self._pages_manager.get_pages(),
)
# Compile the script. Any errors thrown here will be surfaced
# to the user via a modal dialog in the frontend, and won't result
# in their previous script elements disappearing.
try:
if active_script is not None:
script_path = active_script["script_path"]
else:
# page must not be found
script_path = main_page_info["script_path"]
# At this point, we know that either
# * the script corresponding to the hash requested no longer
# exists, or
# * we were not able to find a script with the requested page
# name.
# In both of these cases, we want to send a page_not_found
# message to the frontend.
msg = ForwardMsg()
msg.page_not_found.page_name = rerun_data.page_name
ctx.enqueue(msg)
code = self._script_cache.get_bytecode(script_path)
except Exception as ex:
# We got a compile error. Send an error event and bail immediately.
_LOGGER.exception("Script compilation error", exc_info=ex)
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = False
self.on_event.send(
self,
event=ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR,
exception=ex,
)
return
# If we get here, we've successfully compiled our script. The next step
# is to run it. Errors thrown during execution will be shown to the
# user as ExceptionElements.
# Create fake module. This gives us a name global namespace to
# execute the code in.
module = self._new_module("__main__")
# Install the fake module as the __main__ module. This allows
# the pickle module to work inside the user's code, since it now
# can know the module where the pickled objects stem from.
# IMPORTANT: This means we can't use "if __name__ == '__main__'" in
# our code, as it will point to the wrong module!!!
sys.modules["__main__"] = module
# Add special variables to the module's globals dict.
# Note: The following is a requirement for the CodeHasher to
# work correctly. The CodeHasher is scoped to
# files contained in the directory of __main__.__file__, which we
# assume is the main script directory.
module.__dict__["__file__"] = script_path
def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data):
with (
modified_sys_path(self._main_script_path),
self._set_execing_flag(),
):
# Run callbacks for widgets whose values have changed.
if rerun_data.widget_states is not None:
self._session_state.on_script_will_rerun(
rerun_data.widget_states
)
ctx.on_script_start()
if rerun_data.fragment_id_queue:
for fragment_id in rerun_data.fragment_id_queue:
try:
wrapped_fragment = self._fragment_storage.get(
fragment_id
)
wrapped_fragment()
except FragmentStorageKeyError:
# This can happen if the fragment_id is removed from the
# storage before the script runner gets to it. In this
# case, the fragment is simply skipped.
# Also, only log an error if the fragment is not an
# auto_rerun to avoid noise. If it is an auto_rerun, we
# might have a race condition where the fragment_id is
# removed but the webapp sends a rerun request before
# the removal information has reached the web app
# (see https://github.com/streamlit/streamlit/issues/9080).
if not rerun_data.is_auto_rerun:
_LOGGER.warning(
f"Couldn't find fragment with id {fragment_id}."
" This can happen if the fragment does not"
" exist anymore when this request is processed,"
" for example because a full app rerun happened"
" that did not register the fragment."
" Usually this doesn't happen or no action is"
" required, so its mainly for debugging."
)
except (RerunException, StopException) as e:
# The wrapped_fragment function is executed
# inside of a exec_func_with_error_handling call, so
# there is a correct handler for these exceptions.
raise e
except Exception:
# Ignore exceptions raised by fragments here as we don't
# want to stop the execution of other fragments. The
# error itself is already rendered within the wrapped
# fragment.
pass
else:
if PagesManager.uses_pages_directory:
_mpa_v1(self._main_script_path)
exec(code, module.__dict__)
self._fragment_storage.clear(
new_fragment_ids=ctx.new_fragment_ids
)
self._session_state.maybe_check_serializable()
# check for control requests, e.g. rerun requests have arrived
self._maybe_handle_execution_control_request()
prep_time = timer() - start_time
(
_,
run_without_errors,
rerun_exception_data,
premature_stop,
uncaught_exception,
) = exec_func_with_error_handling(code_to_exec, ctx)
# setting the session state here triggers a yield-callback call
# which reads self._requests and checks for rerun data
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = run_without_errors
if rerun_exception_data:
# The handling for when a full script run or a fragment is stopped early
# is the same, so we only have one ScriptRunnerEvent for this scenario.
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_FOR_RERUN
elif rerun_data.fragment_id_queue:
finished_event = ScriptRunnerEvent.FRAGMENT_STOPPED_WITH_SUCCESS
else:
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
if ctx.gather_usage_stats:
try:
# Create and send page profile information
ctx.enqueue(
create_page_profile_message(
commands=ctx.tracked_commands,
exec_time=to_microseconds(timer() - start_time),
prep_time=to_microseconds(prep_time),
uncaught_exception=(
type(uncaught_exception).__name__
if uncaught_exception
else None
),
)
)
except Exception as ex:
# Always capture all exceptions since we want to make sure that
# the telemetry never causes any issues.
_LOGGER.debug("Failed to create page profile", exc_info=ex)
self._on_script_finished(ctx, finished_event, premature_stop)
# # Use _log_if_error() to make sure we never ever ever stop running the
# # script without meaning to.
_log_if_error(_clean_problem_modules)
if rerun_exception_data is not None:
rerun_data = rerun_exception_data
else:
break
def _on_script_finished(
self, ctx: ScriptRunContext, event: ScriptRunnerEvent, premature_stop: bool
) -> None:
"""Called when our script finishes executing, even if it finished
early with an exception. We perform post-run cleanup here.
"""
# Tell session_state to update itself in response
if not premature_stop:
self._session_state.on_script_finished(ctx.widget_ids_this_run)
# Signal that the script has finished. (We use SCRIPT_STOPPED_WITH_SUCCESS
# even if we were stopped with an exception.)
self.on_event.send(self, event=event)
# Remove orphaned files now that the script has run and files in use
# are marked as active.
runtime.get_instance().media_file_mgr.remove_orphaned_files()
# Force garbage collection to run, to help avoid memory use building up
# This is usually not an issue, but sometimes GC takes time to kick in and
# causes apps to go over resource limits, and forcing it to run between
# script runs is low cost, since we aren't doing much work anyway.
if config.get_option("runner.postScriptGC"):
gc.collect(2)
def _new_module(self, name: str) -> types.ModuleType:
"""Create a new module with the given name."""
return types.ModuleType(name)
def _clean_problem_modules() -> None:
"""Some modules are stateful, so we have to clear their state."""
if "keras" in sys.modules:
try:
keras = sys.modules["keras"]
keras.backend.clear_session()
except Exception:
# We don't want to crash the app if we can't clear the Keras session.
pass
if "matplotlib.pyplot" in sys.modules:
try:
plt = sys.modules["matplotlib.pyplot"]
plt.close("all")
except Exception:
# We don't want to crash the app if we can't close matplotlib
pass
# The reason this is not a decorator is because we want to make it clear at the
# calling location that this function is being used.
def _log_if_error(fn: Callable[[], None]) -> None:
try:
fn()
except Exception as e:
_LOGGER.warning(e)

View File

@@ -0,0 +1,19 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The modules in this package are separated from
the scriptrunner-package, because they are more or less
standalone and other modules import them quite frequently.
This separation helps us to remove dependency cycles.
"""

View File

@@ -0,0 +1,48 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.runtime.scriptrunner_utils.script_requests import RerunData
from streamlit.util import repr_
# We inherit from BaseException to avoid being caught by user code.
# For example, having it inherit from Exception might make st.rerun not
# work in a try/except block.
class ScriptControlException(BaseException): # NOSONAR
"""Base exception for ScriptRunner."""
pass
class StopException(ScriptControlException):
"""Silently stop the execution of the user's script."""
pass
class RerunException(ScriptControlException):
"""Silently stop and rerun the user's script."""
def __init__(self, rerun_data: RerunData):
"""Construct a RerunException.
Parameters
----------
rerun_data : RerunData
The RerunData that should be used to rerun the script
"""
self.rerun_data = rerun_data
def __repr__(self) -> str:
return repr_(self)

View File

@@ -0,0 +1,305 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import threading
from dataclasses import dataclass, field, replace
from enum import Enum
from typing import TYPE_CHECKING, cast
from streamlit import util
from streamlit.proto.Common_pb2 import ChatInputValue as ChatInputValueProto
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
if TYPE_CHECKING:
from streamlit.proto.ClientState_pb2 import ContextInfo
class ScriptRequestType(Enum):
# The ScriptRunner should continue running its script.
CONTINUE = "CONTINUE"
# If the script is running, it should be stopped as soon
# as the ScriptRunner reaches an interrupt point.
# This is a terminal state.
STOP = "STOP"
# A script rerun has been requested. The ScriptRunner should
# handle this request as soon as it reaches an interrupt point.
RERUN = "RERUN"
@dataclass(frozen=True)
class RerunData:
"""Data attached to RERUN requests. Immutable."""
query_string: str = ""
widget_states: WidgetStates | None = None
page_script_hash: str = ""
page_name: str = ""
# A single fragment_id to append to fragment_id_queue.
fragment_id: str | None = None
# The queue of fragment_ids waiting to be run.
fragment_id_queue: list[str] = field(default_factory=list)
is_fragment_scoped_rerun: bool = False
# set to true when a script is rerun by the fragment auto-rerun mechanism
is_auto_rerun: bool = False
# context_info is used to store information from the user browser (e.g. timezone)
context_info: ContextInfo | None = None
def __repr__(self) -> str:
return util.repr_(self)
@dataclass(frozen=True)
class ScriptRequest:
"""A STOP or RERUN request and associated data."""
type: ScriptRequestType
_rerun_data: RerunData | None = None
@property
def rerun_data(self) -> RerunData:
if self.type is not ScriptRequestType.RERUN:
raise RuntimeError("RerunData is only set for RERUN requests.")
return cast("RerunData", self._rerun_data)
def __repr__(self) -> str:
return util.repr_(self)
def _fragment_run_should_not_preempt_script(
fragment_id_queue: list[str],
is_fragment_scoped_rerun: bool,
) -> bool:
"""Returns whether the currently running script should be preempted due to a
fragment rerun.
Reruns corresponding to fragment runs that weren't caused by calls to
`st.rerun(scope="fragment")` should *not* cancel the current script run
as doing so will affect elements outside of the fragment.
"""
return bool(fragment_id_queue) and not is_fragment_scoped_rerun
def _coalesce_widget_states(
old_states: WidgetStates | None, new_states: WidgetStates | None
) -> WidgetStates | None:
"""Coalesce an older WidgetStates into a newer one, and return a new
WidgetStates containing the result.
For most widget values, we just take the latest version.
However, any trigger_values (which are set by buttons) that are True in
`old_states` will be set to True in the coalesced result, so that button
presses don't go missing.
"""
if not old_states and not new_states:
return None
elif not old_states:
return new_states
elif not new_states:
return old_states
states_by_id: dict[str, WidgetState] = {
wstate.id: wstate for wstate in new_states.widgets
}
trigger_value_types = [
("trigger_value", False),
("chat_input_value", ChatInputValueProto(data=None)),
]
for old_state in old_states.widgets:
for trigger_value_type, unset_value in trigger_value_types:
if (
old_state.WhichOneof("value") == trigger_value_type
and getattr(old_state, trigger_value_type) != unset_value
):
new_trigger_val = states_by_id.get(old_state.id)
# It should nearly always be the case that new_trigger_val is None
# here as trigger values are deleted from the client's WidgetStateManager
# as soon as a rerun_script BackMsg is sent to the server. Since it's
# impossible to test that the client sends us state in the expected
# format in a unit test, we test for this behavior in
# e2e_playwright/test_fragment_queue_test.py
if not new_trigger_val or (
# Ensure the corresponding new_state is also a trigger;
# otherwise, a widget that was previously a button/chat_input but no
# longer is could get a bad value.
new_trigger_val.WhichOneof("value") == trigger_value_type
# We only want to take the value of old_state if new_trigger_val is
# unset as the old value may be stale if a newer one was entered.
and getattr(new_trigger_val, trigger_value_type) == unset_value
):
states_by_id[old_state.id] = old_state
coalesced = WidgetStates()
coalesced.widgets.extend(states_by_id.values())
return coalesced
class ScriptRequests:
"""An interface for communicating with a ScriptRunner. Thread-safe.
AppSession makes requests of a ScriptRunner through this class, and
ScriptRunner handles those requests.
"""
def __init__(self):
self._lock = threading.Lock()
self._state = ScriptRequestType.CONTINUE
self._rerun_data = RerunData()
def request_stop(self) -> None:
"""Request that the ScriptRunner stop running. A stopped ScriptRunner
can't be used anymore. STOP requests succeed unconditionally.
"""
with self._lock:
self._state = ScriptRequestType.STOP
def request_rerun(self, new_data: RerunData) -> bool:
"""Request that the ScriptRunner rerun its script.
If the ScriptRunner has been stopped, this request can't be honored:
return False.
Otherwise, record the request and return True. The ScriptRunner will
handle the rerun request as soon as it reaches an interrupt point.
"""
with self._lock:
if self._state == ScriptRequestType.STOP:
# We can't rerun after being stopped.
return False
if self._state == ScriptRequestType.CONTINUE:
# The script is currently running, and we haven't received a request to
# rerun it as of yet. We can handle a rerun request unconditionally so
# just change self._state and set self._rerun_data.
self._state = ScriptRequestType.RERUN
# Convert from a single fragment_id into fragment_id_queue.
if new_data.fragment_id:
new_data = replace(
new_data,
fragment_id=None,
fragment_id_queue=[new_data.fragment_id],
)
self._rerun_data = new_data
return True
if self._state == ScriptRequestType.RERUN:
# We already have an existing Rerun request, so we can coalesce the new
# rerun request into the existing one.
coalesced_states = _coalesce_widget_states(
self._rerun_data.widget_states, new_data.widget_states
)
if new_data.fragment_id:
# This RERUN request corresponds to a new fragment run. We append
# the new fragment ID to the end of the current fragment_id_queue if
# it isn't already contained in it.
fragment_id_queue = [*self._rerun_data.fragment_id_queue]
if new_data.fragment_id not in fragment_id_queue:
fragment_id_queue.append(new_data.fragment_id)
elif new_data.fragment_id_queue:
# new_data contains a new fragment_id_queue, so we just use it.
fragment_id_queue = new_data.fragment_id_queue
else:
# Otherwise, this is a request to rerun the full script, so we want
# to clear out any fragments we have queued to run since they'll all
# be run with the full script anyway.
fragment_id_queue = []
self._rerun_data = RerunData(
query_string=new_data.query_string,
widget_states=coalesced_states,
page_script_hash=new_data.page_script_hash,
page_name=new_data.page_name,
fragment_id_queue=fragment_id_queue,
is_fragment_scoped_rerun=new_data.is_fragment_scoped_rerun,
is_auto_rerun=new_data.is_auto_rerun,
context_info=new_data.context_info,
)
return True
# We'll never get here
raise RuntimeError(f"Unrecognized ScriptRunnerState: {self._state}")
def on_scriptrunner_yield(self) -> ScriptRequest | None:
"""Called by the ScriptRunner when it's at a yield point.
If we have no request or a RERUN request corresponding to one or more fragments
(that is not a fragment-scoped rerun), return None.
If we have a (full script or fragment-scoped) RERUN request, return the request
and set our internal state to CONTINUE.
If we have a STOP request, return the request and remain stopped.
"""
if self._state == ScriptRequestType.CONTINUE or (
self._state == ScriptRequestType.RERUN
and _fragment_run_should_not_preempt_script(
self._rerun_data.fragment_id_queue,
self._rerun_data.is_fragment_scoped_rerun,
)
):
# We avoid taking the lock in the common cases described above. If a STOP or
# preempting RERUN request is received after we've taken this code path, it
# will be handled at the next `on_scriptrunner_yield`, or when
# `on_scriptrunner_ready` is called.
return None
with self._lock:
if self._state == ScriptRequestType.RERUN:
# We already made this check in the fast-path above but need to do so
# again in case our state changed while we were waiting on the lock.
if _fragment_run_should_not_preempt_script(
self._rerun_data.fragment_id_queue,
self._rerun_data.is_fragment_scoped_rerun,
):
return None
self._state = ScriptRequestType.CONTINUE
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
assert self._state == ScriptRequestType.STOP
return ScriptRequest(ScriptRequestType.STOP)
def on_scriptrunner_ready(self) -> ScriptRequest:
"""Called by the ScriptRunner when it's about to run its script for
the first time, and also after its script has successfully completed.
If we have a RERUN request, return the request and set
our internal state to CONTINUE.
If we have a STOP request or no request, set our internal state
to STOP.
"""
with self._lock:
if self._state == ScriptRequestType.RERUN:
self._state = ScriptRequestType.CONTINUE
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
# If we don't have a rerun request, unconditionally change our
# state to STOP.
self._state = ScriptRequestType.STOP
return ScriptRequest(ScriptRequestType.STOP)

View File

@@ -0,0 +1,288 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
import contextlib
import contextvars
import threading
from collections import Counter
from dataclasses import dataclass, field
from typing import (
TYPE_CHECKING,
Callable,
Final,
Union,
)
from urllib import parse
from typing_extensions import TypeAlias
from streamlit.errors import (
NoSessionContext,
StreamlitAPIException,
StreamlitSetPageConfigMustBeFirstCommandError,
)
from streamlit.logger import get_logger
if TYPE_CHECKING:
from pathlib import Path
from streamlit.cursor import RunningCursor
from streamlit.proto.ClientState_pb2 import ContextInfo
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.PageProfile_pb2 import Command
from streamlit.runtime.fragment import FragmentStorage
from streamlit.runtime.pages_manager import PagesManager
from streamlit.runtime.scriptrunner_utils.script_requests import ScriptRequests
from streamlit.runtime.state import SafeSessionState
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
_LOGGER: Final = get_logger(__name__)
UserInfo: TypeAlias = dict[str, Union[str, bool, None]]
# If true, it indicates that we are in a cached function that disallows the usage of
# widgets. Using contextvars to be thread-safe.
in_cached_function: contextvars.ContextVar[bool] = contextvars.ContextVar(
"in_cached_function", default=False
)
@dataclass
class ScriptRunContext:
"""A context object that contains data for a "script run" - that is,
data that's scoped to a single ScriptRunner execution (and therefore also
scoped to a single connected "session").
ScriptRunContext is used internally by virtually every `st.foo()` function.
It is accessed only from the script thread that's created by ScriptRunner,
or from app-created helper threads that have been "attached" to the
ScriptRunContext via `add_script_run_ctx`.
Streamlit code typically retrieves the active ScriptRunContext via the
`get_script_run_ctx` function.
"""
session_id: str
_enqueue: Callable[[ForwardMsg], None]
query_string: str
session_state: SafeSessionState
uploaded_file_mgr: UploadedFileManager
main_script_path: str
user_info: UserInfo
fragment_storage: FragmentStorage
pages_manager: PagesManager
context_info: ContextInfo | None = None
gather_usage_stats: bool = False
command_tracking_deactivated: bool = False
tracked_commands: list[Command] = field(default_factory=list)
tracked_commands_counter: Counter[str] = field(default_factory=collections.Counter)
_set_page_config_allowed: bool = True
_has_script_started: bool = False
widget_ids_this_run: set[str] = field(default_factory=set)
widget_user_keys_this_run: set[str] = field(default_factory=set)
form_ids_this_run: set[str] = field(default_factory=set)
cursors: dict[int, RunningCursor] = field(default_factory=dict)
script_requests: ScriptRequests | None = None
current_fragment_id: str | None = None
fragment_ids_this_run: list[str] | None = None
new_fragment_ids: set[str] = field(default_factory=set)
_active_script_hash: str = ""
# we allow only one dialog to be open at the same time
has_dialog_opened: bool = False
# TODO(willhuang1997): Remove this variable when experimental query params are removed
_experimental_query_params_used = False
_production_query_params_used = False
@property
def page_script_hash(self):
return self.pages_manager.current_page_script_hash
@property
def active_script_hash(self):
return self._active_script_hash
@property
def main_script_parent(self) -> Path:
return self.pages_manager.main_script_parent
@contextlib.contextmanager
def run_with_active_hash(self, page_hash: str):
original_page_hash = self._active_script_hash
self._active_script_hash = page_hash
try:
yield
finally:
# in the event of any exception, ensure we set the active hash back
self._active_script_hash = original_page_hash
def set_mpa_v2_page(self, page_script_hash: str):
self._active_script_hash = self.pages_manager.main_script_hash
self.pages_manager.set_current_page_script_hash(page_script_hash)
def reset(
self,
query_string: str = "",
page_script_hash: str = "",
fragment_ids_this_run: list[str] | None = None,
context_info: ContextInfo | None = None,
) -> None:
self.cursors = {}
self.widget_ids_this_run = set()
self.widget_user_keys_this_run = set()
self.form_ids_this_run = set()
self.query_string = query_string
self.context_info = context_info
self.pages_manager.set_current_page_script_hash(page_script_hash)
self._active_script_hash = self.pages_manager.main_script_hash
# Permit set_page_config when the ScriptRunContext is reused on a rerun
self._set_page_config_allowed = True
self._has_script_started = False
self.command_tracking_deactivated: bool = False
self.tracked_commands = []
self.tracked_commands_counter = collections.Counter()
self.current_fragment_id = None
self.current_fragment_delta_path: list[int] = []
self.fragment_ids_this_run = fragment_ids_this_run
self.new_fragment_ids = set()
self.has_dialog_opened = False
in_cached_function.set(False)
parsed_query_params = parse.parse_qs(query_string, keep_blank_values=True)
with self.session_state.query_params() as qp:
qp.clear_with_no_forward_msg()
for key, val in parsed_query_params.items():
if len(val) == 0:
qp.set_with_no_forward_msg(key, val="")
elif len(val) == 1:
qp.set_with_no_forward_msg(key, val=val[-1])
else:
qp.set_with_no_forward_msg(key, val)
def on_script_start(self) -> None:
self._has_script_started = True
def enqueue(self, msg: ForwardMsg) -> None:
"""Enqueue a ForwardMsg for this context's session."""
if msg.HasField("page_config_changed") and not self._set_page_config_allowed:
raise StreamlitSetPageConfigMustBeFirstCommandError()
# We want to disallow set_page config if one of the following occurs:
# - set_page_config was called on this message
# - The script has already started and a different st call occurs (a delta)
if msg.HasField("page_config_changed") or (
msg.HasField("delta") and self._has_script_started
):
self._set_page_config_allowed = False
msg.metadata.active_script_hash = self.active_script_hash
# Pass the message up to our associated ScriptRunner.
self._enqueue(msg)
def ensure_single_query_api_used(self):
if self._experimental_query_params_used and self._production_query_params_used:
raise StreamlitAPIException(
"Using `st.query_params` together with either `st.experimental_get_query_params` "
"or `st.experimental_set_query_params` is not supported. Please convert your app "
"to only use `st.query_params`"
)
def mark_experimental_query_params_used(self):
self._experimental_query_params_used = True
self.ensure_single_query_api_used()
def mark_production_query_params_used(self):
self._production_query_params_used = True
self.ensure_single_query_api_used()
SCRIPT_RUN_CONTEXT_ATTR_NAME: Final = "streamlit_script_run_ctx"
def add_script_run_ctx(
thread: threading.Thread | None = None, ctx: ScriptRunContext | None = None
):
"""Adds the current ScriptRunContext to a newly-created thread.
This should be called from this thread's parent thread,
before the new thread starts.
Parameters
----------
thread : threading.Thread
The thread to attach the current ScriptRunContext to.
ctx : ScriptRunContext or None
The ScriptRunContext to add, or None to use the current thread's
ScriptRunContext.
Returns
-------
threading.Thread
The same thread that was passed in, for chaining.
"""
if thread is None:
thread = threading.current_thread()
if ctx is None:
ctx = get_script_run_ctx()
if ctx is not None:
setattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, ctx)
return thread
def get_script_run_ctx(suppress_warning: bool = False) -> ScriptRunContext | None:
"""
Parameters
----------
suppress_warning : bool
If True, don't log a warning if there's no ScriptRunContext.
Returns
-------
ScriptRunContext | None
The current thread's ScriptRunContext, or None if it doesn't have one.
"""
thread = threading.current_thread()
ctx: ScriptRunContext | None = getattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, None)
if ctx is None and not suppress_warning:
# Only warn about a missing ScriptRunContext if suppress_warning is False, and
# we were started via `streamlit run`. Otherwise, the user is likely running a
# script "bare", and doesn't need to be warned about streamlit
# bits that are irrelevant when not connected to a session.
_LOGGER.warning(
"Thread '%s': missing ScriptRunContext! This warning can be ignored when "
"running in bare mode.",
thread.name,
)
return ctx
def enqueue_message(msg: ForwardMsg) -> None:
"""Enqueues a ForwardMsg proto to send to the app."""
ctx = get_script_run_ctx()
if ctx is None:
raise NoSessionContext()
if ctx.current_fragment_id and msg.WhichOneof("type") == "delta":
msg.delta.fragment_id = ctx.current_fragment_id
ctx.enqueue(msg)

View File

@@ -0,0 +1,531 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import threading
from collections.abc import ItemsView, Iterator, KeysView, Mapping, ValuesView
from copy import deepcopy
from typing import (
Any,
Callable,
Final,
NoReturn,
)
from blinker import Signal
import streamlit as st
import streamlit.watcher.path_watcher
from streamlit import runtime
from streamlit.errors import StreamlitSecretNotFoundError
from streamlit.logger import get_logger
_LOGGER: Final = get_logger(__name__)
class SecretErrorMessages:
"""SecretErrorMessages stores all error messages we use for secrets to allow customization for different environments.
For example Streamlit Cloud can customize the message to be different than the open source.
For internal use, may change in future releases without notice.
"""
def __init__(self):
self.missing_attr_message = lambda attr_name: (
f'st.secrets has no attribute "{attr_name}". '
f"Did you forget to add it to secrets.toml, mount it to secret directory, or the app settings on Streamlit Cloud? "
f"More info: https://docs.streamlit.io/deploy/streamlit-community-cloud/deploy-your-app/secrets-management"
)
self.missing_key_message = lambda key: (
f'st.secrets has no key "{key}". '
f"Did you forget to add it to secrets.toml, mount it to secret directory, or the app settings on Streamlit Cloud? "
f"More info: https://docs.streamlit.io/deploy/streamlit-community-cloud/deploy-your-app/secrets-management"
)
self.no_secrets_found = lambda file_paths: (
f"No secrets found. Valid paths for a secrets.toml file or secret directories are: {', '.join(file_paths)}"
)
self.error_parsing_file_at_path = (
lambda path, ex: f"Error parsing secrets file at {path}: {ex}"
)
self.subfolder_path_is_not_a_folder = lambda sub_folder_path: (
f"{sub_folder_path} is not a folder. "
"To use directory based secrets, mount every secret in a subfolder under the secret directory"
)
self.invalid_secret_path = lambda path: (
f"Invalid secrets path: {path}: path is not a .toml file or a directory"
)
def set_missing_attr_message(self, message: Callable[[str], str]) -> None:
"""Set the missing attribute error message."""
self.missing_attr_message = message
def set_missing_key_message(self, message: Callable[[str], str]) -> None:
"""Set the missing key error message."""
self.missing_key_message = message
def set_no_secrets_found_message(self, message: Callable[[list[str]], str]) -> None:
"""Set the no secrets found error message."""
self.no_secrets_found = message
def set_error_parsing_file_at_path_message(
self, message: Callable[[str, Exception], str]
) -> None:
"""Set the error parsing file at path error message."""
self.error_parsing_file_at_path = message
def set_subfolder_path_is_not_a_folder_message(
self, message: Callable[[str], str]
) -> None:
"""Set the subfolder path is not a folder error message."""
self.subfolder_path_is_not_a_folder = message
def set_invalid_secret_path_message(self, message: Callable[[str], str]) -> None:
"""Set the invalid secret path error message."""
self.invalid_secret_path = message
def get_missing_attr_message(self, attr_name: str) -> str:
"""Get the missing attribute error message."""
return self.missing_attr_message(attr_name)
def get_missing_key_message(self, key: str) -> str:
"""Get the missing key error message."""
return self.missing_key_message(key)
def get_no_secrets_found_message(self, file_paths: list[str]) -> str:
"""Get the no secrets found error message."""
return self.no_secrets_found(file_paths)
def get_error_parsing_file_at_path_message(self, path: str, ex: Exception) -> str:
"""Get the error parsing file at path error message."""
return self.error_parsing_file_at_path(path, ex)
def get_subfolder_path_is_not_a_folder_message(self, sub_folder_path: str) -> str:
"""Get the subfolder path is not a folder error message."""
return self.subfolder_path_is_not_a_folder(sub_folder_path)
def get_invalid_secret_path_message(self, path: str) -> str:
"""Get the invalid secret path error message."""
return self.invalid_secret_path(path)
secret_error_messages_singleton: Final = SecretErrorMessages()
def _convert_to_dict(obj: Mapping[str, Any] | AttrDict) -> dict[str, Any]:
"""Convert Mapping or AttrDict objects to dictionaries."""
if isinstance(obj, AttrDict):
return obj.to_dict()
return {k: v.to_dict() if isinstance(v, AttrDict) else v for k, v in obj.items()}
def _missing_attr_error_message(attr_name: str) -> str:
return secret_error_messages_singleton.get_missing_attr_message(attr_name)
def _missing_key_error_message(key: str) -> str:
return secret_error_messages_singleton.get_missing_key_message(key)
class AttrDict(Mapping[str, Any]):
"""We use AttrDict to wrap up dictionary values from secrets
to provide dot access to nested secrets.
"""
def __init__(self, value):
self.__dict__["__nested_secrets__"] = dict(value)
@staticmethod
def _maybe_wrap_in_attr_dict(value) -> Any:
if not isinstance(value, Mapping):
return value
else:
return AttrDict(value)
def __len__(self) -> int:
return len(self.__nested_secrets__)
def __iter__(self) -> Iterator[str]:
return iter(self.__nested_secrets__)
def __getitem__(self, key: str) -> Any:
try:
value = self.__nested_secrets__[key]
return self._maybe_wrap_in_attr_dict(value)
except KeyError:
raise KeyError(_missing_key_error_message(key))
def __getattr__(self, attr_name: str) -> Any:
try:
value = self.__nested_secrets__[attr_name]
return self._maybe_wrap_in_attr_dict(value)
except KeyError:
raise AttributeError(_missing_attr_error_message(attr_name))
def __repr__(self):
return repr(self.__nested_secrets__)
def __setitem__(self, key, value) -> NoReturn:
raise TypeError("Secrets does not support item assignment.")
def __setattr__(self, key, value) -> NoReturn:
raise TypeError("Secrets does not support attribute assignment.")
def to_dict(self) -> dict[str, Any]:
return deepcopy(self.__nested_secrets__)
class Secrets(Mapping[str, Any]):
"""A dict-like class that stores secrets.
Parses secrets.toml on-demand. Cannot be externally mutated.
Safe to use from multiple threads.
"""
def __init__(self):
# Our secrets dict.
self._secrets: Mapping[str, Any] | None = None
self._lock = threading.RLock()
self._file_watchers_installed = False
self.file_change_listener = Signal(
doc="Emitted when a `secrets.toml` file has been changed."
)
def load_if_toml_exists(self) -> bool:
"""Load secrets.toml files from disk if they exists. If none exist,
no exception will be raised. (If a file exists but is malformed,
an exception *will* be raised.).
Returns True if a secrets.toml file was successfully parsed, False otherwise.
Thread-safe.
"""
try:
self._parse()
return True
except StreamlitSecretNotFoundError:
# No secrets.toml files exist. That's fine.
return False
def set_suppress_print_error_on_exception(
self, suppress_print_error_on_exception: bool
) -> None:
"""Left in place for compatibility with integrations until integration
code can be updated.
"""
pass
def _reset(self) -> None:
"""Clear the secrets dictionary and remove any secrets that were
added to os.environ.
Thread-safe.
"""
with self._lock:
if self._secrets is None:
return
for k, v in self._secrets.items():
self._maybe_delete_environment_variable(k, v)
self._secrets = None
def _parse_toml_file(self, path: str) -> tuple[Mapping[str, Any], bool]:
"""Parse a TOML file and return the secrets as a dictionary."""
secrets = {}
found_secrets_file = False
try:
with open(path, encoding="utf-8") as f:
secrets_file_str = f.read()
found_secrets_file = True
except FileNotFoundError:
# the default config for secrets contains two paths. It's likely one of will not have secrets file.
return {}, False
try:
import toml
secrets.update(toml.loads(secrets_file_str))
except (TypeError, toml.TomlDecodeError) as ex:
msg = (
secret_error_messages_singleton.get_error_parsing_file_at_path_message(
path, ex
)
)
raise StreamlitSecretNotFoundError(msg) from ex
return secrets, found_secrets_file
def _parse_directory(self, path: str) -> tuple[Mapping[str, Any], bool]:
"""Parse a directory for secrets. Directory style can be used to support Kubernetes secrets that are mounted to folders.
Example structure:
- top_level_secret_folder
- user_pass_secret (folder)
- username (file), content: myuser
- password (file), content: mypassword
- my_plain_secret (folder)
- regular_secret (file), content: mysecret
See: https://kubernetes.io/docs/tasks/inject-data-application/distribute-credentials-secure/#create-a-pod-that-has-access-to-the-secret-data-through-a-volume
And: https://docs.snowflake.com/en/developer-guide/snowpark-container-services/additional-considerations-services-jobs#passing-secrets-in-local-container-files
"""
secrets: dict[str, Any] = {}
found_secrets_file = False
for dirname in os.listdir(path):
sub_folder_path = os.path.join(path, dirname)
if not os.path.isdir(sub_folder_path):
error_msg = secret_error_messages_singleton.get_subfolder_path_is_not_a_folder_message(
sub_folder_path
)
raise StreamlitSecretNotFoundError(error_msg)
sub_secrets = {}
for filename in os.listdir(sub_folder_path):
file_path = os.path.join(sub_folder_path, filename)
# ignore folders
if os.path.isdir(file_path):
continue
with open(file_path) as f:
sub_secrets[filename] = f.read().strip()
found_secrets_file = True
if len(sub_secrets) == 1:
# if there's just one file, collapse it so it's directly under `dirname`
secrets[dirname] = sub_secrets[list(sub_secrets.keys())[0]]
else:
secrets[dirname] = sub_secrets
return secrets, found_secrets_file
def _parse_file_path(self, path: str) -> tuple[Mapping[str, Any], bool]:
if path.endswith(".toml"):
return self._parse_toml_file(path)
if os.path.isdir(path):
return self._parse_directory(path)
error_msg = secret_error_messages_singleton.get_invalid_secret_path_message(
path
)
raise StreamlitSecretNotFoundError(error_msg)
def _parse(self) -> Mapping[str, Any]:
"""Parse our secrets.toml files if they're not already parsed.
This function is safe to call from multiple threads.
Parameters
----------
print_exceptions : bool
If True, then exceptions will be printed with `st.error` before
being re-raised.
Raises
------
StreamlitSecretNotFoundError
Raised if secrets.toml doesn't exist.
"""
# Avoid taking a lock for the common case where secrets are already
# loaded.
secrets = self._secrets
if secrets is not None:
return secrets
with self._lock:
if self._secrets is not None:
return self._secrets
secrets = {}
file_paths = st.config.get_option("secrets.files")
found_secrets_file = False
for path in file_paths:
path_secrets, found_secrets_file_in_path = self._parse_file_path(path)
found_secrets_file = found_secrets_file or found_secrets_file_in_path
secrets.update(path_secrets)
if not found_secrets_file:
error_msg = (
secret_error_messages_singleton.get_no_secrets_found_message(
file_paths
)
)
raise StreamlitSecretNotFoundError(error_msg)
for k, v in secrets.items():
self._maybe_set_environment_variable(k, v)
self._secrets = secrets
self._maybe_install_file_watchers()
return self._secrets
def to_dict(self) -> dict[str, Any]:
"""Converts the secrets store into a nested dictionary, where nested AttrDict objects are also converted into dictionaries."""
secrets = self._parse()
return _convert_to_dict(secrets)
@staticmethod
def _maybe_set_environment_variable(k: Any, v: Any) -> None:
"""Add the given key/value pair to os.environ if the value
is a string, int, or float.
"""
value_type = type(v)
if value_type in (str, int, float):
os.environ[k] = str(v)
@staticmethod
def _maybe_delete_environment_variable(k: Any, v: Any) -> None:
"""Remove the given key/value pair from os.environ if the value
is a string, int, or float.
"""
value_type = type(v)
if value_type in (str, int, float) and os.environ.get(k) == v:
del os.environ[k]
def _maybe_install_file_watchers(self) -> None:
with self._lock:
if self._file_watchers_installed:
return
file_paths = st.config.get_option("secrets.files")
for path in file_paths:
try:
if path.endswith(".toml"):
streamlit.watcher.path_watcher.watch_file(
path,
self._on_secrets_changed,
watcher_type="poll",
)
else:
streamlit.watcher.path_watcher.watch_dir(
path,
self._on_secrets_changed,
watcher_type="poll",
)
except FileNotFoundError:
# A user may only have one secrets.toml file defined, so we'd expect
# FileNotFoundErrors to be raised when attempting to install a
# watcher on the nonexistent ones.
pass
# We set file_watchers_installed to True even if the installation attempt
# failed to avoid repeatedly trying to install it.
self._file_watchers_installed = True
def _on_secrets_changed(self, changed_file_path) -> None:
with self._lock:
_LOGGER.debug("Secret path %s changed, reloading", changed_file_path)
self._reset()
self._parse()
# Emit a signal to notify receivers that the `secrets.toml` file
# has been changed.
self.file_change_listener.send()
def __getattr__(self, key: str) -> Any:
"""Return the value with the given key. If no such key
exists, raise an AttributeError.
Thread-safe.
"""
try:
value = self._parse()[key]
if not isinstance(value, Mapping):
return value
else:
return AttrDict(value)
# We add FileNotFoundError since __getattr__ is expected to only raise
# AttributeError. Without handling FileNotFoundError, unittests.mocks
# fails during mock creation on Python3.9
except (KeyError, FileNotFoundError):
raise AttributeError(_missing_attr_error_message(key))
def __getitem__(self, key: str) -> Any:
"""Return the value with the given key. If no such key
exists, raise a KeyError.
Thread-safe.
"""
try:
value = self._parse()[key]
if not isinstance(value, Mapping):
return value
else:
return AttrDict(value)
except KeyError:
raise KeyError(_missing_key_error_message(key))
def __setattr__(self, key: str, value: Any) -> None:
# Allow internal attributes to be set
if key in {
"_secrets",
"_lock",
"_file_watchers_installed",
"_suppress_print_error_on_exception",
"file_change_listener",
"load_if_toml_exists",
}:
super().__setattr__(key, value)
else:
raise TypeError("Secrets does not support attribute assignment.")
def __repr__(self) -> str:
# If the runtime is NOT initialized, it is a method call outside
# the streamlit app, so we avoid reading the secrets file as it may not exist.
# If the runtime is initialized, display the contents of the file and
# the file must already exist.
"""A string representation of the contents of the dict. Thread-safe."""
if not runtime.exists():
return f"{self.__class__.__name__}"
return repr(self._parse())
def __len__(self) -> int:
"""The number of entries in the dict. Thread-safe."""
return len(self._parse())
def has_key(self, k: str) -> bool:
"""True if the given key is in the dict. Thread-safe."""
return k in self._parse()
def keys(self) -> KeysView[str]:
"""A view of the keys in the dict. Thread-safe."""
return self._parse().keys()
def values(self) -> ValuesView[Any]:
"""A view of the values in the dict. Thread-safe."""
return self._parse().values()
def items(self) -> ItemsView[str, Any]:
"""A view of the key-value items in the dict. Thread-safe."""
return self._parse().items()
def __contains__(self, key: Any) -> bool:
"""True if the given key is in the dict. Thread-safe."""
return key in self._parse()
def __iter__(self) -> Iterator[str]:
"""An iterator over the keys in the dict. Thread-safe."""
return iter(self._parse())
secrets_singleton: Final = Secrets()

View File

@@ -0,0 +1,394 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Protocol, cast
if TYPE_CHECKING:
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
class SessionClientDisconnectedError(Exception):
"""Raised by operations on a disconnected SessionClient."""
class SessionClient(Protocol):
"""Interface for sending data to a session's client."""
@abstractmethod
def write_forward_msg(self, msg: ForwardMsg) -> None:
"""Deliver a ForwardMsg to the client.
If the SessionClient has been disconnected, it should raise a
SessionClientDisconnectedError.
"""
raise NotImplementedError
@dataclass
class ActiveSessionInfo:
"""Type containing data related to an active session.
This type is nearly identical to SessionInfo. The difference is that when using it,
we are guaranteed that SessionClient is not None.
"""
client: SessionClient
session: AppSession
script_run_count: int = 0
@dataclass
class SessionInfo:
"""Type containing data related to an AppSession.
For each AppSession, the Runtime tracks that session's
script_run_count. This is used to track the age of messages in
the ForwardMsgCache.
"""
client: SessionClient | None
session: AppSession
script_run_count: int = 0
def is_active(self) -> bool:
return self.client is not None
def to_active(self) -> ActiveSessionInfo:
assert self.is_active(), "A SessionInfo with no client cannot be active!"
# NOTE: The cast here (rather than copying this SessionInfo's fields into a new
# ActiveSessionInfo) is important as the Runtime expects to be able to mutate
# what's returned from get_active_session_info to increment script_run_count.
return cast("ActiveSessionInfo", self)
class SessionStorageError(Exception):
"""Exception class for errors raised by SessionStorage.
The original error that causes a SessionStorageError to be (re)raised will generally
be an I/O error specific to the concrete SessionStorage implementation.
"""
class SessionStorage(Protocol):
@abstractmethod
def get(self, session_id: str) -> SessionInfo | None:
"""Return the SessionInfo corresponding to session_id, or None if one does not
exist.
Parameters
----------
session_id
The unique ID of the session being fetched.
Returns
-------
SessionInfo or None
Raises
------
SessionStorageError
Raised if an error occurs while attempting to fetch the session. This will
generally happen if there is an error with the underlying storage backend
(e.g. if we lose our connection to it).
"""
raise NotImplementedError
@abstractmethod
def save(self, session_info: SessionInfo) -> None:
"""Save the given session.
Parameters
----------
session_info
The SessionInfo being saved.
Raises
------
SessionStorageError
Raised if an error occurs while saving the given session.
"""
raise NotImplementedError
@abstractmethod
def delete(self, session_id: str) -> None:
"""Mark the session corresponding to session_id for deletion and stop tracking
it.
Note that:
- Calling delete on an ID corresponding to a nonexistent session is a no-op.
- Calling delete on an ID should cause the given session to no longer be
tracked by this SessionStorage, but exactly when and how the session's data
is eventually cleaned up is a detail left up to the implementation.
Parameters
----------
session_id
The unique ID of the session to delete.
Raises
------
SessionStorageError
Raised if an error occurs while attempting to delete the session.
"""
raise NotImplementedError
@abstractmethod
def list(self) -> list[SessionInfo]:
"""List all sessions tracked by this SessionStorage.
Returns
-------
List[SessionInfo]
Raises
------
SessionStorageError
Raised if an error occurs while attempting to list sessions.
"""
raise NotImplementedError
class SessionManager(Protocol):
"""SessionManagers are responsible for encapsulating all session lifecycle behavior
that the Streamlit Runtime may care about.
A SessionManager must define the following required methods:
- __init__
- connect_session
- close_session
- get_session_info
- list_sessions
SessionManager implementations may also choose to define the notions of active and
inactive sessions. The precise definitions of active/inactive are left to the
concrete implementation. SessionManagers that wish to differentiate between active
and inactive sessions should have the required methods listed above operate on *all*
sessions. Additionally, they should define the following methods for working with
active sessions:
- disconnect_session
- get_active_session_info
- is_active_session
- list_active_sessions
When active session-related methods are left undefined, their default
implementations are the naturally corresponding required methods.
The Runtime, unless there's a good reason to do otherwise, should generally work
with the active-session versions of a SessionManager's methods. There isn't currently
a need for us to be able to operate on inactive sessions stored in SessionStorage
outside of the SessionManager itself. However, it's highly likely that we'll
eventually have to do so, which is why the abstractions allow for this now.
Notes
-----
Threading: All SessionManager methods are *not* threadsafe -- they must be called
from the runtime's eventloop thread.
"""
@abstractmethod
def __init__(
self,
session_storage: SessionStorage,
uploaded_file_manager: UploadedFileManager,
script_cache: ScriptCache,
message_enqueued_callback: Callable[[], None] | None,
) -> None:
"""Initialize a SessionManager with the given SessionStorage.
Parameters
----------
session_storage
The SessionStorage instance backing this SessionManager.
uploaded_file_manager
Used to manage files uploaded by users via the Streamlit web client.
script_cache
ScriptCache instance. Caches user script bytecode.
message_enqueued_callback
A callback invoked after a message is enqueued to be sent to a web client.
"""
raise NotImplementedError
@abstractmethod
def connect_session(
self,
client: SessionClient,
script_data: ScriptData,
user_info: dict[str, str | bool | None],
existing_session_id: str | None = None,
session_id_override: str | None = None,
) -> str:
"""Create a new session or connect to an existing one.
Parameters
----------
client
A concrete SessionClient implementation for communicating with
the session's client.
script_data
Contains parameters related to running a script.
user_info
A dict that contains information about the session's user. For now,
it only (optionally) contains the user's email address.
{
"email": "example@example.com"
}
existing_session_id
The ID of an existing session to reconnect to. If one is not provided, a new
session is created. Note that whether a SessionManager supports reconnecting
to an existing session is left up to the concrete SessionManager
implementation. Those that do not support reconnection should simply ignore
this argument.
session_id_override
The ID to assign to a new session being created with this method. Setting
this can be useful when the service that a Streamlit Runtime is running in
wants to tie the lifecycle of a Streamlit session to some other session-like
object that it manages. Only one of existing_session_id and
session_id_override should be set.
Returns
-------
str
The session's unique string ID.
"""
raise NotImplementedError
@abstractmethod
def close_session(self, session_id: str) -> None:
"""Close and completely delete the session with the given id.
This function may be called multiple times for the same session,
which is not an error. (Subsequent calls just no-op.)
Parameters
----------
session_id
The session's unique ID.
"""
raise NotImplementedError
@abstractmethod
def get_session_info(self, session_id: str) -> SessionInfo | None:
"""Return the SessionInfo for the given id, or None if no such session
exists.
Parameters
----------
session_id
The session's unique ID.
Returns
-------
SessionInfo or None
"""
raise NotImplementedError
@abstractmethod
def list_sessions(self) -> list[SessionInfo]:
"""Return the SessionInfo for all sessions managed by this SessionManager.
Returns
-------
List[SessionInfo]
"""
raise NotImplementedError
def num_sessions(self) -> int:
"""Return the number of sessions tracked by this SessionManager.
Subclasses of SessionManager shouldn't provide their own implementation of this
method without a *very* good reason.
Returns
-------
int
"""
return len(self.list_sessions())
# NOTE: The following methods only need to be overwritten when a concrete
# SessionManager implementation has a notion of active vs inactive sessions.
# If left unimplemented in a subclass, the default implementations of these methods
# call corresponding SessionManager methods in a natural way.
def disconnect_session(self, session_id: str) -> None:
"""Disconnect the given session.
This method should be idempotent.
Parameters
----------
session_id
The session's unique ID.
"""
self.close_session(session_id)
def get_active_session_info(self, session_id: str) -> ActiveSessionInfo | None:
"""Return the ActiveSessionInfo for the given id, or None if either no such
session exists or the session is not active.
Parameters
----------
session_id
The active session's unique ID.
Returns
-------
ActiveSessionInfo or None
"""
session = self.get_session_info(session_id)
if session is None or not session.is_active():
return None
return session.to_active()
def is_active_session(self, session_id: str) -> bool:
"""Return True if the given session exists and is active, False otherwise.
Returns
-------
bool
"""
return self.get_active_session_info(session_id) is not None
def list_active_sessions(self) -> list[ActiveSessionInfo]:
"""Return the session info for all active sessions tracked by this SessionManager.
Returns
-------
List[ActiveSessionInfo]
"""
return [s.to_active() for s in self.list_sessions()]
def num_active_sessions(self) -> int:
"""Return the number of active sessions tracked by this SessionManager.
Subclasses of SessionManager shouldn't provide their own implementation of this
method without a *very* good reason.
Returns
-------
int
"""
return len(self.list_active_sessions())

View File

@@ -0,0 +1,41 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.runtime.state.common import WidgetArgs, WidgetCallback, WidgetKwargs
from streamlit.runtime.state.query_params_proxy import QueryParamsProxy
from streamlit.runtime.state.safe_session_state import SafeSessionState
from streamlit.runtime.state.session_state import (
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
SessionState,
SessionStateStatProvider,
)
from streamlit.runtime.state.session_state_proxy import (
SessionStateProxy,
get_session_state,
)
from streamlit.runtime.state.widgets import register_widget
__all__ = [
"WidgetArgs",
"WidgetCallback",
"WidgetKwargs",
"QueryParamsProxy",
"SafeSessionState",
"SCRIPT_RUN_WITHOUT_ERRORS_KEY",
"SessionState",
"SessionStateStatProvider",
"SessionStateProxy",
"get_session_state",
"register_widget",
]

View File

@@ -0,0 +1,191 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and data structures shared by session_state.py and widgets.py."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import (
Any,
Callable,
Final,
Generic,
Literal,
TypeVar,
cast,
get_args,
)
from typing_extensions import TypeAlias, TypeGuard
from streamlit import util
from streamlit.errors import (
StreamlitAPIException,
)
GENERATED_ELEMENT_ID_PREFIX: Final = "$$ID"
TESTING_KEY = "$$STREAMLIT_INTERNAL_KEY_TESTING"
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
WidgetArgs: TypeAlias = tuple[Any, ...]
WidgetKwargs: TypeAlias = dict[str, Any]
WidgetCallback: TypeAlias = Callable[..., None]
# A deserializer receives the value from whatever field is set on the
# WidgetState proto, and returns a regular python value. A serializer
# receives a regular python value, and returns something suitable for
# a value field on WidgetState proto. They should be inverses.
WidgetDeserializer: TypeAlias = Callable[[Any, str], T]
WidgetSerializer: TypeAlias = Callable[[T], Any]
# The array value field names are part of the larger set of possible value
# field names. See the explanation for said set below. The message types
# associated with these fields are distinguished by storing data in a `data`
# field in their messages, meaning they need special treatment in certain
# circumstances. Hence, they need their own, dedicated, sub-type.
ArrayValueFieldName: TypeAlias = Literal[
"double_array_value",
"int_array_value",
"string_array_value",
]
# A frozenset containing the allowed values of the ArrayValueFieldName type.
# Useful for membership checking.
_ARRAY_VALUE_FIELD_NAMES: Final = frozenset(
cast(
"tuple[ArrayValueFieldName, ...]",
# NOTE: get_args is not recursive, so this only works as long as
# ArrayValueFieldName remains flat.
get_args(ArrayValueFieldName),
)
)
# These are the possible field names that can be set in the `value` oneof-field
# of the WidgetState message (schema found in .proto/WidgetStates.proto).
# We need these as a literal type to ensure correspondence with the protobuf
# schema in certain parts of the python code.
# TODO(harahu): It would be preferable if this type was automatically derived
# from the protobuf schema, rather than manually maintained. Not sure how to
# achieve that, though.
ValueFieldName: TypeAlias = Literal[
ArrayValueFieldName,
"arrow_value",
"bool_value",
"bytes_value",
"double_value",
"file_uploader_state_value",
"int_value",
"json_value",
"string_value",
"trigger_value",
"string_trigger_value",
"chat_input_value",
]
def is_array_value_field_name(obj: object) -> TypeGuard[ArrayValueFieldName]:
return obj in _ARRAY_VALUE_FIELD_NAMES
@dataclass(frozen=True)
class WidgetMetadata(Generic[T]):
"""Metadata associated with a single widget. Immutable."""
id: str
deserializer: WidgetDeserializer[T] = field(repr=False)
serializer: WidgetSerializer[T] = field(repr=False)
value_type: ValueFieldName
# An optional user-code callback invoked when the widget's value changes.
# Widget callbacks are called at the start of a script run, before the
# body of the script is executed.
callback: WidgetCallback | None = None
callback_args: WidgetArgs | None = None
callback_kwargs: WidgetKwargs | None = None
fragment_id: str | None = None
def __repr__(self) -> str:
return util.repr_(self)
@dataclass(frozen=True)
class RegisterWidgetResult(Generic[T_co]):
"""Result returned by the `register_widget` family of functions/methods.
Should be usable by widget code to determine what value to return, and
whether to update the UI.
Parameters
----------
value : T_co
The widget's current value, or, in cases where the true widget value
could not be determined, an appropriate fallback value.
This value should be returned by the widget call.
value_changed : bool
True if the widget's value is different from the value most recently
returned from the frontend.
Implies an update to the frontend is needed.
"""
value: T_co
value_changed: bool
@classmethod
def failure(
cls, deserializer: WidgetDeserializer[T_co]
) -> RegisterWidgetResult[T_co]:
"""The canonical way to construct a RegisterWidgetResult in cases
where the true widget value could not be determined.
"""
return cls(value=deserializer(None, ""), value_changed=False)
def user_key_from_element_id(element_id: str) -> str | None:
"""Return the user key portion of a element id, or None if the id does not
have a user key.
TODO This will incorrectly indicate no user key if the user actually provides
"None" as a key, but we can't avoid this kind of problem while storing the
string representation of the no-user-key sentinel as part of the element id.
"""
user_key: str | None = element_id.split("-", maxsplit=2)[-1]
return None if user_key == "None" else user_key
def is_element_id(key: str) -> bool:
"""True if the given session_state key has the structure of a element ID."""
return key.startswith(GENERATED_ELEMENT_ID_PREFIX)
def is_keyed_element_id(key: str) -> bool:
"""True if the given session_state key has the structure of a element ID
with a user_key.
"""
return is_element_id(key) and not key.endswith("-None")
def require_valid_user_key(key: str) -> None:
"""Raise an Exception if the given user_key is invalid."""
if is_element_id(key):
raise StreamlitAPIException(
f"Keys beginning with {GENERATED_ELEMENT_ID_PREFIX} are reserved."
)

View File

@@ -0,0 +1,205 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterable, Iterator, MutableMapping
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Final
from urllib import parse
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem
EMBED_QUERY_PARAM: Final[str] = "embed"
EMBED_OPTIONS_QUERY_PARAM: Final[str] = "embed_options"
EMBED_QUERY_PARAMS_KEYS: Final[list[str]] = [
EMBED_QUERY_PARAM,
EMBED_OPTIONS_QUERY_PARAM,
]
@dataclass
class QueryParams(MutableMapping[str, str]):
"""A lightweight wrapper of a dict that sends forwardMsgs when state changes.
It stores str keys with str and List[str] values.
"""
_query_params: dict[str, list[str] | str] = field(default_factory=dict)
def __iter__(self) -> Iterator[str]:
self._ensure_single_query_api_used()
return iter(
key
for key in self._query_params.keys()
if key not in EMBED_QUERY_PARAMS_KEYS
)
def __getitem__(self, key: str) -> str:
"""Retrieves a value for a given key in query parameters.
Returns the last item in a list or an empty string if empty.
If the key is not present, raise KeyError.
"""
self._ensure_single_query_api_used()
try:
if key in EMBED_QUERY_PARAMS_KEYS:
raise KeyError(missing_key_error_message(key))
value = self._query_params[key]
if isinstance(value, list):
if len(value) == 0:
return ""
else:
# Return the last value to mimic Tornado's behavior
# https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.get_query_argument
return value[-1]
return value
except KeyError:
raise KeyError(missing_key_error_message(key))
def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
self._ensure_single_query_api_used()
self.__set_item_internal(key, value)
self._send_query_param_msg()
def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None:
if isinstance(value, dict):
raise StreamlitAPIException(
f"You cannot set a query params key `{key}` to a dictionary."
)
if key in EMBED_QUERY_PARAMS_KEYS:
raise StreamlitAPIException(
"Query param embed and embed_options (case-insensitive) cannot be set programmatically."
)
# Type checking users should handle the string serialization themselves
# We will accept any type for the list and serialize to str just in case
if isinstance(value, Iterable) and not isinstance(value, str):
self._query_params[key] = [str(item) for item in value]
else:
self._query_params[key] = str(value)
def __delitem__(self, key: str) -> None:
self._ensure_single_query_api_used()
try:
if key in EMBED_QUERY_PARAMS_KEYS:
raise KeyError(missing_key_error_message(key))
del self._query_params[key]
self._send_query_param_msg()
except KeyError:
raise KeyError(missing_key_error_message(key))
def update(
self,
other: Iterable[tuple[str, str | Iterable[str]]]
| SupportsKeysAndGetItem[str, str | Iterable[str]] = (),
/,
**kwds: str,
):
# This overrides the `update` provided by MutableMapping
# to ensure only one one ForwardMsg is sent.
self._ensure_single_query_api_used()
if hasattr(other, "keys") and hasattr(other, "__getitem__"):
for key in other.keys():
self.__set_item_internal(key, other[key])
else:
for key, value in other:
self.__set_item_internal(key, value)
for key, value in kwds.items():
self.__set_item_internal(key, value)
self._send_query_param_msg()
def get_all(self, key: str) -> list[str]:
self._ensure_single_query_api_used()
if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS:
return []
value = self._query_params[key]
return value if isinstance(value, list) else [value]
def __len__(self) -> int:
self._ensure_single_query_api_used()
return len(
{key for key in self._query_params if key not in EMBED_QUERY_PARAMS_KEYS}
)
def __str__(self) -> str:
self._ensure_single_query_api_used()
return str(self._query_params)
def _send_query_param_msg(self) -> None:
ctx = get_script_run_ctx()
if ctx is None:
return
self._ensure_single_query_api_used()
msg = ForwardMsg()
msg.page_info_changed.query_string = parse.urlencode(
self._query_params, doseq=True
)
ctx.query_string = msg.page_info_changed.query_string
ctx.enqueue(msg)
def clear(self) -> None:
self._ensure_single_query_api_used()
self.clear_with_no_forward_msg(preserve_embed=True)
self._send_query_param_msg()
def to_dict(self) -> dict[str, str]:
self._ensure_single_query_api_used()
# return the last query param if multiple values are set
return {
key: self[key]
for key in self._query_params
if key not in EMBED_QUERY_PARAMS_KEYS
}
def from_dict(
self,
_dict: Iterable[tuple[str, str | Iterable[str]]]
| SupportsKeysAndGetItem[str, str | Iterable[str]],
):
self._ensure_single_query_api_used()
old_value = self._query_params.copy()
self.clear_with_no_forward_msg(preserve_embed=True)
try:
self.update(_dict)
except StreamlitAPIException:
# restore the original from before we made any changes.
self._query_params = old_value
raise
def set_with_no_forward_msg(self, key: str, val: list[str] | str) -> None:
self._query_params[key] = val
def clear_with_no_forward_msg(self, preserve_embed: bool = False) -> None:
self._query_params = {
key: value
for key, value in self._query_params.items()
if key in EMBED_QUERY_PARAMS_KEYS and preserve_embed
}
def _ensure_single_query_api_used(self):
ctx = get_script_run_ctx()
if ctx is None:
return
ctx.mark_production_query_params_used()
def missing_key_error_message(key: str) -> str:
return f'st.query_params has no key "{key}".'

View File

@@ -0,0 +1,218 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterable, Iterator, MutableMapping
from typing import TYPE_CHECKING, overload
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.state.session_state_proxy import get_session_state
if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem
class QueryParamsProxy(MutableMapping[str, str]):
"""
A stateless singleton that proxies ``st.query_params`` interactions
to the current script thread's QueryParams instance.
"""
def __iter__(self) -> Iterator[str]:
with get_session_state().query_params() as qp:
return iter(qp)
def __len__(self) -> int:
with get_session_state().query_params() as qp:
return len(qp)
def __str__(self) -> str:
with get_session_state().query_params() as qp:
return str(qp)
@gather_metrics("query_params.get_item")
def __getitem__(self, key: str) -> str:
with get_session_state().query_params() as qp:
try:
return qp[key]
except KeyError:
raise KeyError(self.missing_key_error_message(key))
def __delitem__(self, key: str) -> None:
with get_session_state().query_params() as qp:
del qp[key]
@gather_metrics("query_params.set_item")
def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
with get_session_state().query_params() as qp:
qp[key] = value
@gather_metrics("query_params.get_attr")
def __getattr__(self, key: str) -> str:
with get_session_state().query_params() as qp:
try:
return qp[key]
except KeyError:
raise AttributeError(self.missing_attr_error_message(key))
def __delattr__(self, key: str) -> None:
with get_session_state().query_params() as qp:
try:
del qp[key]
except KeyError:
raise AttributeError(self.missing_key_error_message(key))
@overload
def update(
self, mapping: SupportsKeysAndGetItem[str, str | Iterable[str]], /, **kwds: str
) -> None: ...
@overload
def update(
self, keys_and_values: Iterable[tuple[str, str | Iterable[str]]], /, **kwds: str
) -> None: ...
@overload
def update(self, **kwds: str | Iterable[str]) -> None: ...
def update(self, other=(), /, **kwds):
"""
Update one or more values in query_params at once from a dictionary or
dictionary-like object.
See `Mapping.update()` from Python's `collections` library.
Parameters
----------
other: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]]
A dictionary or mapping of strings to strings.
**kwds: str
Additional key/value pairs to update passed as keyword arguments.
"""
with get_session_state().query_params() as qp:
qp.update(other, **kwds)
@gather_metrics("query_params.set_attr")
def __setattr__(self, key: str, value: str | Iterable[str]) -> None:
with get_session_state().query_params() as qp:
qp[key] = value
@gather_metrics("query_params.get_all")
def get_all(self, key: str) -> list[str]:
"""
Get a list of all query parameter values associated to a given key.
When a key is repeated as a query parameter within the URL, this method
allows all values to be obtained. In contrast, dict-like methods only
retrieve the last value when a key is repeated in the URL.
Parameters
----------
key: str
The label of the query parameter in the URL.
Returns
-------
List[str]
A list of values associated to the given key. May return zero, one,
or multiple values.
"""
with get_session_state().query_params() as qp:
return qp.get_all(key)
@gather_metrics("query_params.clear")
def clear(self) -> None:
"""
Clear all query parameters from the URL of the app.
Returns
-------
None
"""
with get_session_state().query_params() as qp:
qp.clear()
@gather_metrics("query_params.to_dict")
def to_dict(self) -> dict[str, str]:
"""
Get all query parameters as a dictionary.
This method primarily exists for internal use and is not needed for
most cases. ``st.query_params`` returns an object that inherits from
``dict`` by default.
When a key is repeated as a query parameter within the URL, this method
will return only the last value of each unique key.
Returns
-------
Dict[str,str]
A dictionary of the current query paramters in the app's URL.
"""
with get_session_state().query_params() as qp:
return qp.to_dict()
@overload
def from_dict(
self, keys_and_values: Iterable[tuple[str, str | Iterable[str]]]
) -> None: ...
@overload
def from_dict(
self, mapping: SupportsKeysAndGetItem[str, str | Iterable[str]]
) -> None: ...
@gather_metrics("query_params.from_dict")
def from_dict(self, params):
"""
Set all of the query parameters from a dictionary or dictionary-like object.
This method primarily exists for advanced users who want to control
multiple query parameters in a single update. To set individual query
parameters, use key or attribute notation instead.
This method inherits limitations from ``st.query_params`` and can't be
used to set embedding options as described in `Embed your app \
<https://docs.streamlit.io/deploy/streamlit-community-cloud/share-your-app/embed-your-app#embed-options>`_.
To handle repeated keys, the value in a key-value pair should be a list.
.. note::
``.from_dict()`` is not a direct inverse of ``.to_dict()`` if
you are working with repeated keys. A true inverse operation is
``{key: st.query_params.get_all(key) for key in st.query_params}``.
Parameters
----------
params: dict
A dictionary used to replace the current query parameters.
Example
-------
>>> import streamlit as st
>>>
>>> st.query_params.from_dict({"foo": "bar", "baz": [1, "two"]})
"""
with get_session_state().query_params() as qp:
return qp.from_dict(params)
@staticmethod
def missing_key_error_message(key: str) -> str:
return f'st.query_params has no key "{key}".'
@staticmethod
def missing_attr_error_message(key: str) -> str:
return f'st.query_params has no attribute "{key}".'

View File

@@ -0,0 +1,138 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import threading
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
from collections.abc import Iterator
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
from streamlit.runtime.state.common import RegisterWidgetResult, T, WidgetMetadata
from streamlit.runtime.state.query_params import QueryParams
from streamlit.runtime.state.session_state import SessionState
class SafeSessionState:
"""Thread-safe wrapper around SessionState.
When AppSession gets a re-run request, it can interrupt its existing
ScriptRunner and spin up a new ScriptRunner to handle the request.
When this happens, the existing ScriptRunner will continue executing
its script until it reaches a yield point - but during this time, it
must not mutate its SessionState.
"""
_state: SessionState
_lock: threading.RLock
_yield_callback: Callable[[], None]
def __init__(self, state: SessionState, yield_callback: Callable[[], None]):
# Fields must be set using the object's setattr method to avoid
# infinite recursion from trying to look up the fields we're setting.
object.__setattr__(self, "_state", state)
# TODO: we'd prefer this be a threading.Lock instead of RLock -
# but `call_callbacks` first needs to be rewritten.
object.__setattr__(self, "_lock", threading.RLock())
object.__setattr__(self, "_yield_callback", yield_callback)
def register_widget(
self, metadata: WidgetMetadata[T], user_key: str | None
) -> RegisterWidgetResult[T]:
self._yield_callback()
with self._lock:
return self._state.register_widget(metadata, user_key)
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
self._yield_callback()
with self._lock:
# TODO: rewrite this to copy the callbacks list into a local
# variable so that we don't need to hold our lock for the
# duration. (This will also allow us to downgrade our RLock
# to a Lock.)
self._state.on_script_will_rerun(latest_widget_states)
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
with self._lock:
self._state.on_script_finished(widget_ids_this_run)
def maybe_check_serializable(self) -> None:
with self._lock:
self._state.maybe_check_serializable()
def get_widget_states(self) -> list[WidgetStateProto]:
"""Return a list of serialized widget values for each widget with a value."""
with self._lock:
return self._state.get_widget_states()
def is_new_state_value(self, user_key: str) -> bool:
with self._lock:
return self._state.is_new_state_value(user_key)
@property
def filtered_state(self) -> dict[str, Any]:
"""The combined session and widget state, excluding keyless widgets."""
with self._lock:
return self._state.filtered_state
def __getitem__(self, key: str) -> Any:
self._yield_callback()
with self._lock:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._yield_callback()
with self._lock:
self._state[key] = value
def __delitem__(self, key: str) -> None:
self._yield_callback()
with self._lock:
del self._state[key]
def __contains__(self, key: str) -> bool:
self._yield_callback()
with self._lock:
return key in self._state
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(f"{key} not found in session_state.")
def __setattr__(self, key: str, value: Any) -> None:
self[key] = value
def __delattr__(self, key: str) -> None:
try:
del self[key]
except KeyError:
raise AttributeError(f"{key} not found in session_state.")
def __repr__(self):
"""Presents itself as a simple dict of the underlying SessionState instance."""
kv = ((k, self._state[k]) for k in self._state._keys())
s = ", ".join(f"{k}: {v!r}" for k, v in kv)
return f"{{{s}}}"
@contextmanager
def query_params(self) -> Iterator[QueryParams]:
self._yield_callback()
with self._lock:
yield self._state.query_params

View File

@@ -0,0 +1,773 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import pickle
from collections.abc import Iterator, KeysView, MutableMapping
from copy import deepcopy
from dataclasses import dataclass, field, replace
from typing import (
TYPE_CHECKING,
Any,
Final,
Union,
cast,
)
from typing_extensions import TypeAlias
import streamlit as st
from streamlit import config, util
from streamlit.errors import StreamlitAPIException, UnserializableSessionStateError
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
from streamlit.runtime.state.common import (
RegisterWidgetResult,
T,
ValueFieldName,
WidgetMetadata,
is_array_value_field_name,
is_element_id,
is_keyed_element_id,
)
from streamlit.runtime.state.query_params import QueryParams
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
if TYPE_CHECKING:
from streamlit.runtime.session_manager import SessionManager
STREAMLIT_INTERNAL_KEY_PREFIX: Final = "$$STREAMLIT_INTERNAL_KEY"
SCRIPT_RUN_WITHOUT_ERRORS_KEY: Final = (
f"{STREAMLIT_INTERNAL_KEY_PREFIX}_SCRIPT_RUN_WITHOUT_ERRORS"
)
@dataclass(frozen=True)
class Serialized:
"""A widget value that's serialized to a protobuf. Immutable."""
value: WidgetStateProto
@dataclass(frozen=True)
class Value:
"""A widget value that's not serialized. Immutable."""
value: Any
WState: TypeAlias = Union[Value, Serialized]
@dataclass
class WStates(MutableMapping[str, Any]):
"""A mapping of widget IDs to values. Widget values can be stored in
serialized or deserialized form, but when values are retrieved from the
mapping, they'll always be deserialized.
"""
states: dict[str, WState] = field(default_factory=dict)
widget_metadata: dict[str, WidgetMetadata[Any]] = field(default_factory=dict)
def __repr__(self):
return util.repr_(self)
def __getitem__(self, k: str) -> Any:
"""Return the value of the widget with the given key.
If the widget's value is currently stored in serialized form, it
will be deserialized first.
"""
wstate = self.states.get(k)
if wstate is None:
raise KeyError(k)
if isinstance(wstate, Value):
# The widget's value is already deserialized - return it directly.
return wstate.value
# The widget's value is serialized. We deserialize it, and return
# the deserialized value.
metadata = self.widget_metadata.get(k)
if metadata is None:
# No deserializer, which should only happen if state is
# gotten from a reconnecting browser and the script is
# trying to access it. Pretend it doesn't exist.
raise KeyError(k)
value_field_name = cast(
"ValueFieldName",
wstate.value.WhichOneof("value"),
)
value = (
wstate.value.__getattribute__(value_field_name)
if value_field_name # Field name is None if the widget value was cleared
else None
)
if is_array_value_field_name(value_field_name):
# Array types are messages with data in a `data` field
value = value.data
elif value_field_name == "json_value":
value = json.loads(value)
deserialized = metadata.deserializer(value, metadata.id)
# Update metadata to reflect information from WidgetState proto
self.set_widget_metadata(
replace(
metadata,
value_type=value_field_name,
)
)
self.states[k] = Value(deserialized)
return deserialized
def __setitem__(self, k: str, v: WState) -> None:
self.states[k] = v
def __delitem__(self, k: str) -> None:
del self.states[k]
def __len__(self) -> int:
return len(self.states)
def __iter__(self):
# For this and many other methods, we can't simply delegate to the
# states field, because we need to invoke `__getitem__` for any
# values, to handle deserialization and unwrapping of values.
yield from self.states
def keys(self) -> KeysView[str]:
return KeysView(self.states)
def items(self) -> set[tuple[str, Any]]: # type: ignore[override]
return {(k, self[k]) for k in self}
def values(self) -> set[Any]: # type: ignore[override]
return {self[wid] for wid in self}
def update(self, other: WStates) -> None: # type: ignore[override]
"""Copy all widget values and metadata from 'other' into this mapping,
overwriting any data in this mapping that's also present in 'other'.
"""
self.states.update(other.states)
self.widget_metadata.update(other.widget_metadata)
def set_widget_from_proto(self, widget_state: WidgetStateProto) -> None:
"""Set a widget's serialized value, overwriting any existing value it has."""
self[widget_state.id] = Serialized(widget_state)
def set_from_value(self, k: str, v: Any) -> None:
"""Set a widget's deserialized value, overwriting any existing value it has."""
self[k] = Value(v)
def set_widget_metadata(self, widget_meta: WidgetMetadata[Any]) -> None:
"""Set a widget's metadata, overwriting any existing metadata it has."""
self.widget_metadata[widget_meta.id] = widget_meta
def remove_stale_widgets(
self,
active_widget_ids: set[str],
fragment_ids_this_run: list[str] | None,
) -> None:
"""Remove widget state for stale widgets."""
self.states = {
k: v
for k, v in self.states.items()
if not _is_stale_widget(
self.widget_metadata.get(k),
active_widget_ids,
fragment_ids_this_run,
)
}
def get_serialized(self, k: str) -> WidgetStateProto | None:
"""Get the serialized value of the widget with the given id.
If the widget doesn't exist, return None. If the widget exists but
is not in serialized form, it will be serialized first.
"""
item = self.states.get(k)
if item is None:
# No such widget: return None.
return None
if isinstance(item, Serialized):
# Widget value is serialized: return it directly.
return item.value
# Widget value is not serialized: serialize it first!
metadata = self.widget_metadata.get(k)
if metadata is None:
# We're missing the widget's metadata. (Can this happen?)
return None
widget = WidgetStateProto()
widget.id = k
field = metadata.value_type
serialized = metadata.serializer(item.value)
if is_array_value_field_name(field):
arr = getattr(widget, field)
arr.data.extend(serialized)
elif field == "json_value":
setattr(widget, field, json.dumps(serialized))
elif field == "file_uploader_state_value":
widget.file_uploader_state_value.CopyFrom(serialized)
elif field == "string_trigger_value":
widget.string_trigger_value.CopyFrom(serialized)
elif field == "chat_input_value":
widget.chat_input_value.CopyFrom(serialized)
elif field is not None and serialized is not None:
# If the field is None, the widget value was cleared
# by the user and therefore is None. But we cannot
# set it to None here, since the proto properties are
# not nullable. So we just don't set it.
setattr(widget, field, serialized)
return widget
def as_widget_states(self) -> list[WidgetStateProto]:
"""Return a list of serialized widget values for each widget with a value."""
states = [
self.get_serialized(widget_id)
for widget_id in self.states.keys()
if self.get_serialized(widget_id)
]
states = cast("list[WidgetStateProto]", states)
return states
def call_callback(self, widget_id: str) -> None:
"""Call the given widget's callback and return the callback's
return value. If the widget has no callback, return None.
If the widget doesn't exist, raise an Exception.
"""
metadata = self.widget_metadata.get(widget_id)
assert metadata is not None
callback = metadata.callback
if callback is None:
return
args = metadata.callback_args or ()
kwargs = metadata.callback_kwargs or {}
callback(*args, **kwargs)
def _missing_key_error_message(key: str) -> str:
return (
f'st.session_state has no key "{key}". Did you forget to initialize it? '
f"More info: https://docs.streamlit.io/develop/concepts/architecture/session-state#initialization"
)
@dataclass
class KeyIdMapper:
"""A mapping of user-provided keys to element IDs.
It also maps element IDs to user-provided keys so that this reverse mapping
does not have to be computed ad-hoc.
All built-in dict-operations such as setting and deleting expect the key as the
argument, not the element ID.
"""
_key_id_mapping: dict[str, str] = field(default_factory=dict)
_id_key_mapping: dict[str, str] = field(default_factory=dict)
def __contains__(self, key: str) -> bool:
return key in self._key_id_mapping
def __setitem__(self, key: str, widget_id: Any) -> None:
self._key_id_mapping[key] = widget_id
self._id_key_mapping[widget_id] = key
def __delitem__(self, key: str) -> None:
self.delete(key)
@property
def id_key_mapping(self) -> dict[str, str]:
return self._id_key_mapping
def set_key_id_mapping(self, key_id_mapping: dict[str, str]) -> None:
self._key_id_mapping = key_id_mapping
self._id_key_mapping = {v: k for k, v in key_id_mapping.items()}
def get_id_from_key(self, key: str, default: Any = None) -> str:
return self._key_id_mapping.get(key, default)
def get_key_from_id(self, widget_id: str) -> str:
return self._id_key_mapping[widget_id]
def update(self, other: KeyIdMapper) -> None:
self._key_id_mapping.update(other._key_id_mapping)
self._id_key_mapping.update(other._id_key_mapping)
def clear(self):
self._key_id_mapping.clear()
self._id_key_mapping.clear()
def delete(self, key: str):
widget_id = self._key_id_mapping[key]
del self._key_id_mapping[key]
del self._id_key_mapping[widget_id]
@dataclass
class SessionState:
"""SessionState allows users to store values that persist between app
reruns.
Example
-------
>>> if "num_script_runs" not in st.session_state:
... st.session_state.num_script_runs = 0
>>> st.session_state.num_script_runs += 1
>>> st.write(st.session_state.num_script_runs) # writes 1
The next time your script runs, the value of
st.session_state.num_script_runs will be preserved.
>>> st.session_state.num_script_runs += 1
>>> st.write(st.session_state.num_script_runs) # writes 2
"""
# All the values from previous script runs, squished together to save memory
_old_state: dict[str, Any] = field(default_factory=dict)
# Values set in session state during the current script run, possibly for
# setting a widget's value. Keyed by a user provided string.
_new_session_state: dict[str, Any] = field(default_factory=dict)
# Widget values from the frontend, usually one changing prompted the script rerun
_new_widget_state: WStates = field(default_factory=WStates)
# Keys used for widgets will be eagerly converted to the matching element id
_key_id_mapper: KeyIdMapper = field(default_factory=KeyIdMapper)
# query params are stored in session state because query params will be tied with
# widget state at one point.
query_params: QueryParams = field(default_factory=QueryParams)
def __repr__(self):
return util.repr_(self)
# is it possible for a value to get through this without being deserialized?
def _compact_state(self) -> None:
"""Copy all current session_state and widget_state values into our
_old_state dict, and then clear our current session_state and
widget_state.
"""
for key_or_wid in self:
try:
self._old_state[key_or_wid] = self[key_or_wid]
except KeyError:
# handle key errors from widget state not having metadata gracefully
# https://github.com/streamlit/streamlit/issues/7206
pass
self._new_session_state.clear()
self._new_widget_state.clear()
def clear(self) -> None:
"""Reset self completely, clearing all current and old values."""
self._old_state.clear()
self._new_session_state.clear()
self._new_widget_state.clear()
self._key_id_mapper.clear()
@property
def filtered_state(self) -> dict[str, Any]:
"""The combined session and widget state, excluding keyless widgets."""
wid_key_map = self._key_id_mapper.id_key_mapping
state: dict[str, Any] = {}
# We can't write `for k, v in self.items()` here because doing so will
# run into a `KeyError` if widget metadata has been cleared (which
# happens when the streamlit server restarted or the cache was cleared),
# then we receive a widget's state from a browser.
for k in self._keys():
if not is_element_id(k) and not _is_internal_key(k):
state[k] = self[k]
elif is_keyed_element_id(k):
try:
key = wid_key_map[k]
state[key] = self[k]
except KeyError:
# Widget id no longer maps to a key, it is a not yet
# cleared value in old state for a reset widget
pass
return state
def _keys(self) -> set[str]:
"""All keys active in Session State, with widget keys converted
to widget ids when one is known. (This includes autogenerated keys
for widgets that don't have user_keys defined, and which aren't
exposed to user code).
"""
old_keys = {self._get_widget_id(k) for k in self._old_state.keys()}
new_widget_keys = set(self._new_widget_state.keys())
new_session_state_keys = {
self._get_widget_id(k) for k in self._new_session_state.keys()
}
return old_keys | new_widget_keys | new_session_state_keys
def is_new_state_value(self, user_key: str) -> bool:
"""True if a value with the given key is in the current session state."""
return user_key in self._new_session_state
def __iter__(self) -> Iterator[Any]:
"""Return an iterator over the keys of the SessionState.
This is a shortcut for `iter(self.keys())`.
"""
return iter(self._keys())
def __len__(self) -> int:
"""Return the number of items in SessionState."""
return len(self._keys())
def __getitem__(self, key: str) -> Any:
wid_key_map = self._key_id_mapper.id_key_mapping
widget_id = self._get_widget_id(key)
if widget_id in wid_key_map and widget_id == key:
# the "key" is a raw widget id, so get its associated user key for lookup
key = wid_key_map[widget_id]
try:
return self._getitem(widget_id, key)
except KeyError:
raise KeyError(_missing_key_error_message(key))
def _getitem(self, widget_id: str | None, user_key: str | None) -> Any:
"""Get the value of an entry in Session State, using either the
user-provided key or a widget id as appropriate for the internal dict
being accessed.
At least one of the arguments must have a value.
"""
assert user_key is not None or widget_id is not None
if user_key is not None:
try:
return self._new_session_state[user_key]
except KeyError:
pass
if widget_id is not None:
try:
return self._new_widget_state[widget_id]
except KeyError:
pass
# Typically, there won't be both a widget id and an associated state key in
# old state at the same time, so the order we check is arbitrary.
# The exception is if session state is set and then a later run has
# a widget created, so the widget id entry should be newer.
# The opposite case shouldn't happen, because setting the value of a widget
# through session state will result in the next widget state reflecting that
# value.
if widget_id is not None:
try:
return self._old_state[widget_id]
except KeyError:
pass
if user_key is not None:
try:
return self._old_state[user_key]
except KeyError:
pass
# We'll never get here
raise KeyError
def __setitem__(self, user_key: str, value: Any) -> None:
"""Set the value of the session_state entry with the given user_key.
If the key corresponds to a widget or form that's been instantiated
during the current script run, raise a StreamlitAPIException instead.
"""
ctx = get_script_run_ctx()
if ctx is not None:
widget_id = self._key_id_mapper.get_id_from_key(user_key, None)
widget_ids = ctx.widget_ids_this_run
form_ids = ctx.form_ids_this_run
if widget_id in widget_ids or user_key in form_ids:
raise StreamlitAPIException(
f"`st.session_state.{user_key}` cannot be modified after the widget"
f" with key `{user_key}` is instantiated."
)
self._new_session_state[user_key] = value
def __delitem__(self, key: str) -> None:
widget_id = self._get_widget_id(key)
if not (key in self or widget_id in self):
raise KeyError(_missing_key_error_message(key))
if key in self._new_session_state:
del self._new_session_state[key]
if key in self._old_state:
del self._old_state[key]
if key in self._key_id_mapper:
self._key_id_mapper.delete(key)
if widget_id in self._new_widget_state:
del self._new_widget_state[widget_id]
if widget_id in self._old_state:
del self._old_state[widget_id]
def set_widgets_from_proto(self, widget_states: WidgetStatesProto) -> None:
"""Set the value of all widgets represented in the given WidgetStatesProto."""
for state in widget_states.widgets:
self._new_widget_state.set_widget_from_proto(state)
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
"""Called by ScriptRunner before its script re-runs.
Update widget data and call callbacks on widgets whose value changed
between the previous and current script runs.
"""
# Clear any triggers that weren't reset because the script was disconnected
self._reset_triggers()
self._compact_state()
self.set_widgets_from_proto(latest_widget_states)
self._call_callbacks()
def _call_callbacks(self) -> None:
"""Call any callback associated with each widget whose value
changed between the previous and current script runs.
"""
from streamlit.runtime.scriptrunner import RerunException
changed_widget_ids = [
wid for wid in self._new_widget_state if self._widget_changed(wid)
]
for wid in changed_widget_ids:
try:
self._new_widget_state.call_callback(wid)
except RerunException:
st.warning("Calling st.rerun() within a callback is a no-op.")
def _widget_changed(self, widget_id: str) -> bool:
"""True if the given widget's value changed between the previous
script run and the current script run.
"""
new_value = self._new_widget_state.get(widget_id)
old_value = self._old_state.get(widget_id)
changed: bool = new_value != old_value
return changed
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
"""Called by ScriptRunner after its script finishes running.
Updates widgets to prepare for the next script run.
Parameters
----------
widget_ids_this_run: set[str]
The IDs of the widgets that were accessed during the script
run. Any widget state whose ID does *not* appear in this set
is considered "stale" and will be removed.
"""
self._reset_triggers()
self._remove_stale_widgets(widget_ids_this_run)
def _reset_triggers(self) -> None:
"""Set all trigger values in our state dictionary to False."""
for state_id in self._new_widget_state:
metadata = self._new_widget_state.widget_metadata.get(state_id)
if metadata is not None:
if metadata.value_type == "trigger_value":
self._new_widget_state[state_id] = Value(False)
elif metadata.value_type == "string_trigger_value":
self._new_widget_state[state_id] = Value(None)
elif metadata.value_type == "chat_input_value":
self._new_widget_state[state_id] = Value(None)
for state_id in self._old_state:
metadata = self._new_widget_state.widget_metadata.get(state_id)
if metadata is not None:
if metadata.value_type == "trigger_value":
self._old_state[state_id] = False
elif metadata.value_type == "string_trigger_value":
self._old_state[state_id] = None
elif metadata.value_type == "chat_input_value":
self._old_state[state_id] = None
def _remove_stale_widgets(self, active_widget_ids: set[str]) -> None:
"""Remove widget state for widgets whose ids aren't in `active_widget_ids`."""
ctx = get_script_run_ctx()
if ctx is None:
return
self._new_widget_state.remove_stale_widgets(
active_widget_ids,
ctx.fragment_ids_this_run,
)
# Remove entries from _old_state corresponding to
# widgets not in widget_ids.
self._old_state = {
k: v
for k, v in self._old_state.items()
if (
not is_element_id(k)
or not _is_stale_widget(
self._new_widget_state.widget_metadata.get(k),
active_widget_ids,
ctx.fragment_ids_this_run,
)
)
}
def _set_widget_metadata(self, widget_metadata: WidgetMetadata[Any]) -> None:
"""Set a widget's metadata."""
widget_id = widget_metadata.id
self._new_widget_state.widget_metadata[widget_id] = widget_metadata
def get_widget_states(self) -> list[WidgetStateProto]:
"""Return a list of serialized widget values for each widget with a value."""
return self._new_widget_state.as_widget_states()
def _get_widget_id(self, k: str) -> str:
"""Turns a value that might be a widget id or a user provided key into
an appropriate widget id.
"""
return self._key_id_mapper.get_id_from_key(k, k)
def _set_key_widget_mapping(self, widget_id: str, user_key: str) -> None:
self._key_id_mapper[user_key] = widget_id
def register_widget(
self, metadata: WidgetMetadata[T], user_key: str | None
) -> RegisterWidgetResult[T]:
"""Register a widget with the SessionState.
Returns
-------
RegisterWidgetResult[T]
Contains the widget's current value, and a bool that will be True
if the frontend needs to be updated with the current value.
"""
widget_id = metadata.id
self._set_widget_metadata(metadata)
if user_key is not None:
# If the widget has a user_key, update its user_key:widget_id mapping
self._set_key_widget_mapping(widget_id, user_key)
if widget_id not in self and (user_key is None or user_key not in self):
# This is the first time the widget is registered, so we save its
# value in widget state.
deserializer = metadata.deserializer
initial_widget_value = deepcopy(deserializer(None, metadata.id))
self._new_widget_state.set_from_value(widget_id, initial_widget_value)
# Get the current value of the widget for use as its return value.
# We return a copy, so that reference types can't be accidentally
# mutated by user code.
widget_value = cast("T", self[widget_id])
widget_value = deepcopy(widget_value)
# widget_value_changed indicates to the caller that the widget's
# current value is different from what is in the frontend.
widget_value_changed = user_key is not None and self.is_new_state_value(
user_key
)
return RegisterWidgetResult(widget_value, widget_value_changed)
def __contains__(self, key: str) -> bool:
try:
self[key]
except KeyError:
return False
else:
return True
def get_stats(self) -> list[CacheStat]:
# Lazy-load vendored package to prevent import of numpy
from streamlit.vendor.pympler.asizeof import asizeof
stat = CacheStat("st_session_state", "", asizeof(self))
return [stat]
def _check_serializable(self) -> None:
"""Verify that everything added to session state can be serialized.
We use pickleability as the metric for serializability, and test for
pickleability by just trying it.
"""
for k in self:
try:
pickle.dumps(self[k])
except Exception as e:
err_msg = f"""Cannot serialize the value (of type `{type(self[k])}`) of '{k}' in st.session_state.
Streamlit has been configured to use [pickle](https://docs.python.org/3/library/pickle.html) to
serialize session_state values. Please convert the value to a pickle-serializable type. To learn
more about this behavior, see [our docs](https://docs.streamlit.io/knowledge-base/using-streamlit/serializable-session-state). """
raise UnserializableSessionStateError(err_msg) from e
def maybe_check_serializable(self) -> None:
"""Verify that session state can be serialized, if the relevant config
option is set.
See `_check_serializable` for details.
"""
if config.get_option("runner.enforceSerializableSessionState"):
self._check_serializable()
def _is_internal_key(key: str) -> bool:
return key.startswith(STREAMLIT_INTERNAL_KEY_PREFIX)
def _is_stale_widget(
metadata: WidgetMetadata[Any] | None,
active_widget_ids: set[str],
fragment_ids_this_run: list[str] | None,
) -> bool:
if not metadata:
return True
elif metadata.id in active_widget_ids:
return False
# If we're running 1 or more fragments, but this widget is unrelated to any of the
# fragments that we're running, then it should not be marked as stale as its value
# may still be needed for a future fragment run or full script run.
elif fragment_ids_this_run and metadata.fragment_id not in fragment_ids_this_run:
return False
return True
@dataclass
class SessionStateStatProvider(CacheStatsProvider):
_session_mgr: SessionManager
def get_stats(self) -> list[CacheStat]:
stats: list[CacheStat] = []
for session_info in self._session_mgr.list_active_sessions():
session_state = session_info.session.session_state
stats.extend(session_state.get_stats())
return group_stats(stats)

View File

@@ -0,0 +1,153 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator, MutableMapping
from typing import Any, Final
from streamlit import logger as _logger
from streamlit import runtime
from streamlit.elements.lib.utils import Key
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.state.common import require_valid_user_key
from streamlit.runtime.state.safe_session_state import SafeSessionState
from streamlit.runtime.state.session_state import SessionState
_LOGGER: Final = _logger.get_logger(__name__)
_state_use_warning_already_displayed: bool = False
# The mock session state is used as a fallback if the script is run without `streamlit run`
_mock_session_state: SafeSessionState | None = None
def get_session_state() -> SafeSessionState:
"""Get the SessionState object for the current session.
Note that in streamlit scripts, this function should not be called
directly. Instead, SessionState objects should be accessed via
st.session_state.
"""
global _state_use_warning_already_displayed
from streamlit.runtime.scriptrunner_utils.script_run_context import (
get_script_run_ctx,
)
ctx = get_script_run_ctx()
# If there is no script run context because the script is run bare, we
# use a global mock session state version to allow bare script execution (via python script.py)
if ctx is None:
if not _state_use_warning_already_displayed:
_state_use_warning_already_displayed = True
if not runtime.exists():
_LOGGER.warning(
"Session state does not function when running a script without `streamlit run`"
)
global _mock_session_state
if _mock_session_state is None:
# Lazy initialize the mock session state
_mock_session_state = SafeSessionState(SessionState(), lambda: None)
return _mock_session_state
return ctx.session_state
class SessionStateProxy(MutableMapping[Key, Any]):
"""A stateless singleton that proxies `st.session_state` interactions
to the current script thread's SessionState instance.
The proxy API differs slightly from SessionState: it does not allow
callers to get, set, or iterate over "keyless" widgets (that is, widgets
that were created without a user_key, and have autogenerated keys).
"""
def __iter__(self) -> Iterator[Any]:
"""Iterator over user state and keyed widget values."""
# TODO: this is unsafe if fastReruns is true! Let's deprecate/remove.
return iter(get_session_state().filtered_state)
def __len__(self) -> int:
"""Number of user state and keyed widget values in session_state."""
return len(get_session_state().filtered_state)
def __str__(self) -> str:
"""String representation of user state and keyed widget values."""
return str(get_session_state().filtered_state)
def __getitem__(self, key: Key) -> Any:
"""Return the state or widget value with the given key.
Raises
------
StreamlitAPIException
If the key is not a valid SessionState user key.
"""
key = str(key)
require_valid_user_key(key)
return get_session_state()[key]
@gather_metrics("session_state.set_item")
def __setitem__(self, key: Key, value: Any) -> None:
"""Set the value of the given key.
Raises
------
StreamlitAPIException
If the key is not a valid SessionState user key.
"""
key = str(key)
require_valid_user_key(key)
get_session_state()[key] = value
def __delitem__(self, key: Key) -> None:
"""Delete the value with the given key.
Raises
------
StreamlitAPIException
If the key is not a valid SessionState user key.
"""
key = str(key)
require_valid_user_key(key)
del get_session_state()[key]
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(_missing_attr_error_message(key))
@gather_metrics("session_state.set_attr")
def __setattr__(self, key: str, value: Any) -> None:
self[key] = value
def __delattr__(self, key: str) -> None:
try:
del self[key]
except KeyError:
raise AttributeError(_missing_attr_error_message(key))
def to_dict(self) -> dict[str, Any]:
"""Return a dict containing all session_state and keyed widget values."""
return get_session_state().filtered_state
def _missing_attr_error_message(attr_name: str) -> str:
return (
f'st.session_state has no attribute "{attr_name}". Did you forget to initialize it? '
f"More info: https://docs.streamlit.io/develop/concepts/architecture/session-state#initialization"
)

View File

@@ -0,0 +1,135 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from streamlit.runtime.state.common import (
RegisterWidgetResult,
T,
ValueFieldName,
WidgetArgs,
WidgetCallback,
WidgetDeserializer,
WidgetKwargs,
WidgetMetadata,
WidgetSerializer,
user_key_from_element_id,
)
if TYPE_CHECKING:
from streamlit.runtime.scriptrunner import ScriptRunContext
def register_widget(
element_id: str,
*,
deserializer: WidgetDeserializer[T],
serializer: WidgetSerializer[T],
ctx: ScriptRunContext | None,
on_change_handler: WidgetCallback | None = None,
args: WidgetArgs | None = None,
kwargs: WidgetKwargs | None = None,
value_type: ValueFieldName,
) -> RegisterWidgetResult[T]:
"""Register a widget with Streamlit, and return its current value.
NOTE: This function should be called after the proto has been filled.
Parameters
----------
element_id : str
The id of the element. Must be unique.
deserializer : WidgetDeserializer[T]
Called to convert a widget's protobuf value to the value returned by
its st.<widget_name> function.
serializer : WidgetSerializer[T]
Called to convert a widget's value to its protobuf representation.
ctx : ScriptRunContext or None
Used to ensure uniqueness of widget IDs, and to look up widget values.
on_change_handler : WidgetCallback or None
An optional callback invoked when the widget's value changes.
args : WidgetArgs or None
args to pass to on_change_handler when invoked
kwargs : WidgetKwargs or None
kwargs to pass to on_change_handler when invoked
value_type: ValueType
The value_type the widget is going to use.
We use this information to start with a best-effort guess for the value_type
of each widget. Once we actually receive a proto for a widget from the
frontend, the guess is updated to be the correct type. Unfortunately, we're
not able to always rely on the proto as the type may be needed earlier.
Thankfully, in these cases (when value_type == "trigger_value"), the static
table here being slightly inaccurate should never pose a problem.
Returns
-------
register_widget_result : RegisterWidgetResult[T]
Provides information on which value to return to the widget caller,
and whether the UI needs updating.
- Unhappy path:
- Our ScriptRunContext doesn't exist (meaning that we're running
as a "bare script" outside streamlit).
- We are disconnected from the SessionState instance.
In both cases we'll return a fallback RegisterWidgetResult[T].
- Happy path:
- The widget has already been registered on a previous run but the
user hasn't interacted with it on the client. The widget will have
the default value it was first created with. We then return a
RegisterWidgetResult[T], containing this value.
- The widget has already been registered and the user *has*
interacted with it. The widget will have that most recent
user-specified value. We then return a RegisterWidgetResult[T],
containing this value.
For both paths a widget return value is provided, allowing the widgets
to be used in a non-streamlit setting.
"""
# Create the widget's updated metadata, and register it with session_state.
metadata = WidgetMetadata(
element_id,
deserializer,
serializer,
value_type=value_type,
callback=on_change_handler,
callback_args=args,
callback_kwargs=kwargs,
fragment_id=ctx.current_fragment_id if ctx else None,
)
return register_widget_from_metadata(metadata, ctx)
def register_widget_from_metadata(
metadata: WidgetMetadata[T],
ctx: ScriptRunContext | None,
) -> RegisterWidgetResult[T]:
"""Register a widget and return its value, using an already constructed
`WidgetMetadata`.
This is split out from `register_widget` to allow caching code to replay
widgets by saving and reusing the completed metadata.
See `register_widget` for details on what this returns.
"""
if ctx is None:
# Early-out if we don't have a script run context (which probably means
# we're running as a "bare" Python script, and not via `streamlit run`).
return RegisterWidgetResult.failure(deserializer=metadata.deserializer)
widget_id = metadata.id
user_key = user_key_from_element_id(widget_id)
return ctx.session_state.register_widget(metadata, user_key)

View File

@@ -0,0 +1,109 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import itertools
from abc import abstractmethod
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
if TYPE_CHECKING:
from streamlit.proto.openmetrics_data_model_pb2 import Metric as MetricProto
class CacheStat(NamedTuple):
"""Describes a single cache entry.
Properties
----------
category_name : str
A human-readable name for the cache "category" that the entry belongs
to - e.g. "st.memo", "session_state", etc.
cache_name : str
A human-readable name for cache instance that the entry belongs to.
For "st.memo" and other function decorator caches, this might be the
name of the cached function. If the cache category doesn't have
multiple separate cache instances, this can just be the empty string.
byte_length : int
The entry's memory footprint in bytes.
"""
category_name: str
cache_name: str
byte_length: int
def to_metric_str(self) -> str:
return f'cache_memory_bytes{{cache_type="{self.category_name}",cache="{self.cache_name}"}} {self.byte_length}'
def marshall_metric_proto(self, metric: MetricProto) -> None:
"""Fill an OpenMetrics `Metric` protobuf object."""
label = metric.labels.add()
label.name = "cache_type"
label.value = self.category_name
label = metric.labels.add()
label.name = "cache"
label.value = self.cache_name
metric_point = metric.metric_points.add()
metric_point.gauge_value.int_value = self.byte_length
def group_stats(stats: list[CacheStat]) -> list[CacheStat]:
"""Group a list of CacheStats by category_name and cache_name and sum byte_length."""
def key_function(individual_stat):
return individual_stat.category_name, individual_stat.cache_name
result: list[CacheStat] = []
sorted_stats = sorted(stats, key=key_function)
grouped_stats = itertools.groupby(sorted_stats, key=key_function)
for (category_name, cache_name), single_group_stats in grouped_stats:
result.append(
CacheStat(
category_name=category_name,
cache_name=cache_name,
byte_length=sum(item.byte_length for item in single_group_stats),
)
)
return result
@runtime_checkable
class CacheStatsProvider(Protocol):
@abstractmethod
def get_stats(self) -> list[CacheStat]:
raise NotImplementedError
class StatsManager:
def __init__(self):
self._cache_stats_providers: list[CacheStatsProvider] = []
def register_provider(self, provider: CacheStatsProvider) -> None:
"""Register a CacheStatsProvider with the manager.
This function is not thread-safe. Call it immediately after
creation.
"""
self._cache_stats_providers.append(provider)
def get_stats(self) -> list[CacheStat]:
"""Return a list containing all stats from each registered provider."""
all_stats: list[CacheStat] = []
for provider in self._cache_stats_providers:
all_stats.extend(provider.get_stats())
return all_stats

View File

@@ -0,0 +1,149 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import io
from abc import abstractmethod
from typing import TYPE_CHECKING, NamedTuple, Protocol
from streamlit import util
from streamlit.runtime.stats import CacheStatsProvider
if TYPE_CHECKING:
from collections.abc import Sequence
from streamlit.proto.Common_pb2 import FileURLs as FileURLsProto
class UploadedFileRec(NamedTuple):
"""Metadata and raw bytes for an uploaded file. Immutable."""
file_id: str
name: str
type: str
data: bytes
class UploadFileUrlInfo(NamedTuple):
"""Information we provide for single file in get_upload_urls."""
file_id: str
upload_url: str
delete_url: str
class DeletedFile(NamedTuple):
"""Represents a deleted file in deserialized values for st.file_uploader and
st.camera_input.
Return this from st.file_uploader and st.camera_input deserialize (so they can
be used in session_state), when widget value contains file record that is missing
from the storage.
DeleteFile instances filtered out before return final value to the user in script,
or before sending to frontend.
"""
file_id: str
class UploadedFile(io.BytesIO):
"""A mutable uploaded file.
This class extends BytesIO, which has copy-on-write semantics when
initialized with `bytes`.
"""
def __init__(self, record: UploadedFileRec, file_urls: FileURLsProto):
# BytesIO's copy-on-write semantics doesn't seem to be mentioned in
# the Python docs - possibly because it's a CPython-only optimization
# and not guaranteed to be in other Python runtimes. But it's detailed
# here: https://hg.python.org/cpython/rev/79a5fbe2c78f
super().__init__(record.data)
self.file_id = record.file_id
self.name = record.name
self.type = record.type
self.size = len(record.data)
self._file_urls = file_urls
def __eq__(self, other: object) -> bool:
if not isinstance(other, UploadedFile):
return NotImplemented
return self.file_id == other.file_id
def __repr__(self) -> str:
return util.repr_(self)
class UploadedFileManager(CacheStatsProvider, Protocol):
"""UploadedFileManager protocol, that should be implemented by the concrete
uploaded file managers.
It is responsible for:
- retrieving files by session_id and file_id for st.file_uploader and
st.camera_input
- cleaning up uploaded files associated with session on session end
It should be created during Runtime initialization.
Optionally UploadedFileManager could be responsible for issuing URLs which will be
used by frontend to upload files to.
"""
@abstractmethod
def get_files(
self, session_id: str, file_ids: Sequence[str]
) -> list[UploadedFileRec]:
"""Return a list of UploadedFileRec for a given sequence of file_ids.
Parameters
----------
session_id
The ID of the session that owns the files.
file_ids
The sequence of ids associated with files to retrieve.
Returns
-------
List[UploadedFileRec]
A list of URL UploadedFileRec instances, each instance contains information
about uploaded file.
"""
raise NotImplementedError
@abstractmethod
def remove_session_files(self, session_id: str) -> None:
"""Remove all files associated with a given session."""
raise NotImplementedError
def get_upload_urls(
self, session_id: str, file_names: Sequence[str]
) -> list[UploadFileUrlInfo]:
"""Return a list of UploadFileUrlInfo for a given sequence of file_names.
Optional to implement, issuing of URLs could be done by other service.
Parameters
----------
session_id
The ID of the session that request URLs.
file_names
The sequence of file names for which URLs are requested
Returns
-------
List[UploadFileUrlInfo]
A list of UploadFileUrlInfo instances, each instance contains information
about uploaded file URLs.
"""
raise NotImplementedError

View File

@@ -0,0 +1,167 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Final, cast
from streamlit.logger import get_logger
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.session_manager import (
ActiveSessionInfo,
SessionClient,
SessionInfo,
SessionManager,
SessionStorage,
)
if TYPE_CHECKING:
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
_LOGGER: Final = get_logger(__name__)
class WebsocketSessionManager(SessionManager):
"""A SessionManager used to manage sessions with lifecycles tied to those of a
browser tab's websocket connection.
WebsocketSessionManagers differentiate between "active" and "inactive" sessions.
Active sessions are those with a currently active websocket connection. Inactive
sessions are sessions without. Eventual cleanup of inactive sessions is a detail left
to the specific SessionStorage that a WebsocketSessionManager is instantiated with.
"""
def __init__(
self,
session_storage: SessionStorage,
uploaded_file_manager: UploadedFileManager,
script_cache: ScriptCache,
message_enqueued_callback: Callable[[], None] | None,
) -> None:
self._session_storage = session_storage
self._uploaded_file_mgr = uploaded_file_manager
self._script_cache = script_cache
self._message_enqueued_callback = message_enqueued_callback
# Mapping of AppSession.id -> ActiveSessionInfo.
self._active_session_info_by_id: dict[str, ActiveSessionInfo] = {}
def connect_session(
self,
client: SessionClient,
script_data: ScriptData,
user_info: dict[str, str | bool | None],
existing_session_id: str | None = None,
session_id_override: str | None = None,
) -> str:
assert not (existing_session_id and session_id_override), (
"Only one of existing_session_id and session_id_override should be truthy"
)
if existing_session_id in self._active_session_info_by_id:
_LOGGER.warning(
"Session with id %s is already connected! Connecting to a new session.",
existing_session_id,
)
session_info = (
existing_session_id
and existing_session_id not in self._active_session_info_by_id
and self._session_storage.get(existing_session_id)
)
if session_info:
existing_session = session_info.session
existing_session.register_file_watchers()
self._active_session_info_by_id[existing_session.id] = ActiveSessionInfo(
client,
existing_session,
session_info.script_run_count,
)
self._session_storage.delete(existing_session.id)
return existing_session.id
session = AppSession(
script_data=script_data,
uploaded_file_manager=self._uploaded_file_mgr,
script_cache=self._script_cache,
message_enqueued_callback=self._message_enqueued_callback,
user_info=user_info,
session_id_override=session_id_override,
)
_LOGGER.debug(
"Created new session for client %s. Session ID: %s", id(client), session.id
)
assert session.id not in self._active_session_info_by_id, (
f"session.id '{session.id}' registered multiple times!"
)
self._active_session_info_by_id[session.id] = ActiveSessionInfo(client, session)
return session.id
def disconnect_session(self, session_id: str) -> None:
if session_id in self._active_session_info_by_id:
active_session_info = self._active_session_info_by_id[session_id]
session = active_session_info.session
session.request_script_stop()
session.disconnect_file_watchers()
self._session_storage.save(
SessionInfo(
client=None,
session=session,
script_run_count=active_session_info.script_run_count,
)
)
del self._active_session_info_by_id[session_id]
def get_active_session_info(self, session_id: str) -> ActiveSessionInfo | None:
return self._active_session_info_by_id.get(session_id)
def is_active_session(self, session_id: str) -> bool:
return session_id in self._active_session_info_by_id
def list_active_sessions(self) -> list[ActiveSessionInfo]:
return list(self._active_session_info_by_id.values())
def close_session(self, session_id: str) -> None:
if session_id in self._active_session_info_by_id:
active_session_info = self._active_session_info_by_id[session_id]
del self._active_session_info_by_id[session_id]
active_session_info.session.shutdown()
return
session_info = self._session_storage.get(session_id)
if session_info:
self._session_storage.delete(session_id)
session_info.session.shutdown()
def get_session_info(self, session_id: str) -> SessionInfo | None:
session_info = self.get_active_session_info(session_id)
if session_info:
return cast("SessionInfo", session_info)
return self._session_storage.get(session_id)
def list_sessions(self) -> list[SessionInfo]:
return (
cast("list[SessionInfo]", self.list_active_sessions())
+ self._session_storage.list()
)