Mise à jour de Monitor.py et autres scripts
This commit is contained in:
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from streamlit.runtime.runtime import Runtime, RuntimeConfig, RuntimeState
|
||||
from streamlit.runtime.session_manager import (
|
||||
SessionClient,
|
||||
SessionClientDisconnectedError,
|
||||
)
|
||||
|
||||
|
||||
def get_instance() -> Runtime:
|
||||
"""Return the singleton Runtime instance. Raise an Error if the
|
||||
Runtime hasn't been created yet.
|
||||
"""
|
||||
return Runtime.instance()
|
||||
|
||||
|
||||
def exists() -> bool:
|
||||
"""True if the singleton Runtime instance has been created.
|
||||
|
||||
When a Streamlit app is running in "raw mode" - that is, when the
|
||||
app is run via `python app.py` instead of `streamlit run app.py` -
|
||||
the Runtime will not exist, and various Streamlit functions need
|
||||
to adapt.
|
||||
"""
|
||||
return Runtime.exists()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Runtime",
|
||||
"RuntimeConfig",
|
||||
"RuntimeState",
|
||||
"SessionClient",
|
||||
"SessionClientDisconnectedError",
|
||||
"get_instance",
|
||||
"exists",
|
||||
]
|
||||
@@ -0,0 +1,985 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Callable, Final
|
||||
|
||||
from google.protobuf.json_format import ParseDict
|
||||
|
||||
import streamlit.elements.exception as exception_utils
|
||||
from streamlit import config, runtime
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.ClientState_pb2 import ClientState
|
||||
from streamlit.proto.Common_pb2 import FileURLs, FileURLsRequest
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.proto.GitInfo_pb2 import GitInfo
|
||||
from streamlit.proto.NewSession_pb2 import (
|
||||
Config,
|
||||
CustomThemeConfig,
|
||||
FontFace,
|
||||
NewSession,
|
||||
UserInfo,
|
||||
)
|
||||
from streamlit.runtime import caching
|
||||
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
|
||||
from streamlit.runtime.fragment import FragmentStorage, MemoryFragmentStorage
|
||||
from streamlit.runtime.metrics_util import Installation
|
||||
from streamlit.runtime.pages_manager import PagesManager
|
||||
from streamlit.runtime.scriptrunner import RerunData, ScriptRunner, ScriptRunnerEvent
|
||||
from streamlit.runtime.secrets import secrets_singleton
|
||||
from streamlit.string_util import to_snake_case
|
||||
from streamlit.version import STREAMLIT_VERSION_STRING
|
||||
from streamlit.watcher import LocalSourcesWatcher
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.proto.BackMsg_pb2 import BackMsg
|
||||
from streamlit.runtime.script_data import ScriptData
|
||||
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
|
||||
from streamlit.runtime.state import SessionState
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
from streamlit.source_util import PageHash, PageInfo
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
class AppSessionState(Enum):
|
||||
APP_NOT_RUNNING = "APP_NOT_RUNNING"
|
||||
APP_IS_RUNNING = "APP_IS_RUNNING"
|
||||
SHUTDOWN_REQUESTED = "SHUTDOWN_REQUESTED"
|
||||
|
||||
|
||||
def _generate_scriptrun_id() -> str:
|
||||
"""Randomly generate a unique ID for a script execution."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class AppSession:
|
||||
"""
|
||||
Contains session data for a single "user" of an active app
|
||||
(that is, a connected browser tab).
|
||||
|
||||
Each AppSession has its own ScriptData, root DeltaGenerator, ScriptRunner,
|
||||
and widget state.
|
||||
|
||||
An AppSession is attached to each thread involved in running its script.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
script_data: ScriptData,
|
||||
uploaded_file_manager: UploadedFileManager,
|
||||
script_cache: ScriptCache,
|
||||
message_enqueued_callback: Callable[[], None] | None,
|
||||
user_info: dict[str, str | bool | None],
|
||||
session_id_override: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the AppSession.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
script_data
|
||||
Object storing parameters related to running a script
|
||||
|
||||
uploaded_file_manager
|
||||
Used to manage files uploaded by users via the Streamlit web client.
|
||||
|
||||
script_cache
|
||||
The app's ScriptCache instance. Stores cached user scripts. ScriptRunner
|
||||
uses the ScriptCache to avoid having to reload user scripts from disk
|
||||
on each rerun.
|
||||
|
||||
message_enqueued_callback
|
||||
After enqueuing a message, this callable notification will be invoked.
|
||||
|
||||
user_info
|
||||
A dict that contains information about the current user. For now,
|
||||
it only contains the user's email address.
|
||||
|
||||
{
|
||||
"email": "example@example.com"
|
||||
}
|
||||
|
||||
Information about the current user is optionally provided when a
|
||||
websocket connection is initialized via the "X-Streamlit-User" header.
|
||||
|
||||
session_id_override
|
||||
The ID to assign to this session. Setting this can be useful when the
|
||||
service that a Streamlit Runtime is running in wants to tie the lifecycle of
|
||||
a Streamlit session to some other session-like object that it manages.
|
||||
"""
|
||||
|
||||
# Each AppSession has a unique string ID.
|
||||
self.id = session_id_override or str(uuid.uuid4())
|
||||
|
||||
self._event_loop = asyncio.get_running_loop()
|
||||
self._script_data = script_data
|
||||
self._uploaded_file_mgr = uploaded_file_manager
|
||||
self._script_cache = script_cache
|
||||
self._pages_manager = PagesManager(
|
||||
script_data.main_script_path, self._script_cache
|
||||
)
|
||||
|
||||
# The browser queue contains messages that haven't yet been
|
||||
# delivered to the browser. Periodically, the server flushes
|
||||
# this queue and delivers its contents to the browser.
|
||||
self._browser_queue = ForwardMsgQueue()
|
||||
self._message_enqueued_callback = message_enqueued_callback
|
||||
|
||||
self._state = AppSessionState.APP_NOT_RUNNING
|
||||
|
||||
# Need to remember the client state here because when a script reruns
|
||||
# due to the source code changing we need to pass in the previous client state.
|
||||
self._client_state = ClientState()
|
||||
|
||||
self._local_sources_watcher: LocalSourcesWatcher | None = None
|
||||
self._stop_config_listener: Callable[[], bool] | None = None
|
||||
self._stop_pages_listener: Callable[[], None] | None = None
|
||||
|
||||
if config.get_option("server.fileWatcherType") != "none":
|
||||
self.register_file_watchers()
|
||||
|
||||
self._run_on_save = config.get_option("server.runOnSave")
|
||||
|
||||
self._scriptrunner: ScriptRunner | None = None
|
||||
|
||||
# This needs to be lazily imported to avoid a dependency cycle.
|
||||
from streamlit.runtime.state import SessionState
|
||||
|
||||
self._session_state = SessionState()
|
||||
self._user_info = user_info
|
||||
|
||||
self._debug_last_backmsg_id: str | None = None
|
||||
|
||||
self._fragment_storage: FragmentStorage = MemoryFragmentStorage()
|
||||
|
||||
_LOGGER.debug("AppSession initialized (id=%s)", self.id)
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Ensure that we call shutdown() when an AppSession is garbage collected."""
|
||||
self.shutdown()
|
||||
|
||||
def register_file_watchers(self) -> None:
|
||||
"""Register handlers to be called when various files are changed.
|
||||
|
||||
Files that we watch include:
|
||||
- source files that already exist (for edits)
|
||||
- `.py` files in the the main script's `pages/` directory (for file additions
|
||||
and deletions)
|
||||
- project and user-level config.toml files
|
||||
- the project-level secrets.toml files
|
||||
|
||||
This method is called automatically on AppSession construction, but it may be
|
||||
called again in the case when a session is disconnected and is being reconnect
|
||||
to.
|
||||
"""
|
||||
if self._local_sources_watcher is None:
|
||||
self._local_sources_watcher = LocalSourcesWatcher(self._pages_manager)
|
||||
|
||||
self._local_sources_watcher.register_file_change_callback(
|
||||
self._on_source_file_changed
|
||||
)
|
||||
self._stop_config_listener = config.on_config_parsed(
|
||||
self._on_source_file_changed, force_connect=True
|
||||
)
|
||||
secrets_singleton.file_change_listener.connect(self._on_secrets_file_changed)
|
||||
|
||||
def disconnect_file_watchers(self) -> None:
|
||||
"""Disconnect the file watcher handlers registered by register_file_watchers."""
|
||||
if self._local_sources_watcher is not None:
|
||||
self._local_sources_watcher.close()
|
||||
if self._stop_config_listener is not None:
|
||||
self._stop_config_listener()
|
||||
if self._stop_pages_listener is not None:
|
||||
self._stop_pages_listener()
|
||||
|
||||
secrets_singleton.file_change_listener.disconnect(self._on_secrets_file_changed)
|
||||
|
||||
self._local_sources_watcher = None
|
||||
self._stop_config_listener = None
|
||||
self._stop_pages_listener = None
|
||||
|
||||
def flush_browser_queue(self) -> list[ForwardMsg]:
|
||||
"""Clear the forward message queue and return the messages it contained.
|
||||
|
||||
The Server calls this periodically to deliver new messages
|
||||
to the browser connected to this app.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ForwardMsg]
|
||||
The messages that were removed from the queue and should
|
||||
be delivered to the browser.
|
||||
|
||||
"""
|
||||
return self._browser_queue.flush()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shut down the AppSession.
|
||||
|
||||
It's an error to use a AppSession after it's been shut down.
|
||||
|
||||
"""
|
||||
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
|
||||
_LOGGER.debug("Shutting down (id=%s)", self.id)
|
||||
# Clear any unused session files in upload file manager and media
|
||||
# file manager
|
||||
self._uploaded_file_mgr.remove_session_files(self.id)
|
||||
|
||||
if runtime.exists():
|
||||
rt = runtime.get_instance()
|
||||
rt.media_file_mgr.clear_session_refs(self.id)
|
||||
rt.media_file_mgr.remove_orphaned_files()
|
||||
|
||||
# Shut down the ScriptRunner, if one is active.
|
||||
# self._state must not be set to SHUTDOWN_REQUESTED until
|
||||
# *after* this is called.
|
||||
self.request_script_stop()
|
||||
|
||||
self._state = AppSessionState.SHUTDOWN_REQUESTED
|
||||
|
||||
# Disconnect all file watchers if we haven't already, although we will have
|
||||
# generally already done so by the time we get here.
|
||||
self.disconnect_file_watchers()
|
||||
|
||||
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
|
||||
"""Enqueue a new ForwardMsg to our browser queue.
|
||||
|
||||
This can be called on both the main thread and a ScriptRunner
|
||||
run thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : ForwardMsg
|
||||
The message to enqueue
|
||||
|
||||
"""
|
||||
|
||||
if self._debug_last_backmsg_id:
|
||||
msg.debug_last_backmsg_id = self._debug_last_backmsg_id
|
||||
|
||||
self._browser_queue.enqueue(msg)
|
||||
if self._message_enqueued_callback:
|
||||
self._message_enqueued_callback()
|
||||
|
||||
def handle_backmsg(self, msg: BackMsg) -> None:
|
||||
"""Process a BackMsg."""
|
||||
try:
|
||||
msg_type = msg.WhichOneof("type")
|
||||
if msg_type == "rerun_script":
|
||||
if msg.debug_last_backmsg_id:
|
||||
self._debug_last_backmsg_id = msg.debug_last_backmsg_id
|
||||
|
||||
self._handle_rerun_script_request(msg.rerun_script)
|
||||
elif msg_type == "load_git_info":
|
||||
self._handle_git_information_request()
|
||||
elif msg_type == "clear_cache":
|
||||
self._handle_clear_cache_request()
|
||||
elif msg_type == "app_heartbeat":
|
||||
self._handle_app_heartbeat_request()
|
||||
elif msg_type == "set_run_on_save":
|
||||
self._handle_set_run_on_save_request(msg.set_run_on_save)
|
||||
elif msg_type == "stop_script":
|
||||
self._handle_stop_script_request()
|
||||
elif msg_type == "file_urls_request":
|
||||
self._handle_file_urls_request(msg.file_urls_request)
|
||||
else:
|
||||
_LOGGER.warning('No handler for "%s"', msg_type)
|
||||
|
||||
except Exception as ex:
|
||||
_LOGGER.exception("Error processing back message")
|
||||
self.handle_backmsg_exception(ex)
|
||||
|
||||
def handle_backmsg_exception(self, e: BaseException) -> None:
|
||||
"""Handle an Exception raised while processing a BackMsg from the browser."""
|
||||
# This does a few things:
|
||||
# 1) Clears the current app in the browser.
|
||||
# 2) Marks the current app as "stopped" in the browser.
|
||||
# 3) HACK: Resets any script params that may have been broken (e.g. the
|
||||
# command-line when rerunning with wrong argv[0])
|
||||
|
||||
self._on_scriptrunner_event(
|
||||
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
)
|
||||
self._on_scriptrunner_event(
|
||||
self._scriptrunner,
|
||||
ScriptRunnerEvent.SCRIPT_STARTED,
|
||||
page_script_hash="",
|
||||
)
|
||||
self._on_scriptrunner_event(
|
||||
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
)
|
||||
|
||||
# Send an Exception message to the frontend.
|
||||
# Because _on_scriptrunner_event does its work in an eventloop callback,
|
||||
# this exception ForwardMsg *must* also be enqueued in a callback,
|
||||
# so that it will be enqueued *after* the various ForwardMsgs that
|
||||
# _on_scriptrunner_event sends.
|
||||
self._event_loop.call_soon_threadsafe(
|
||||
lambda: self._enqueue_forward_msg(self._create_exception_message(e))
|
||||
)
|
||||
|
||||
def request_rerun(self, client_state: ClientState | None) -> None:
|
||||
"""Signal that we're interested in running the script.
|
||||
|
||||
If the script is not already running, it will be started immediately.
|
||||
Otherwise, a rerun will be requested.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client_state : streamlit.proto.ClientState_pb2.ClientState | None
|
||||
The ClientState protobuf to run the script with, or None
|
||||
to use previous client state.
|
||||
|
||||
"""
|
||||
|
||||
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
|
||||
_LOGGER.warning("Discarding rerun request after shutdown")
|
||||
return
|
||||
|
||||
if client_state:
|
||||
fragment_id = client_state.fragment_id
|
||||
|
||||
# Early check whether this fragment still exists in the fragment storage or
|
||||
# might have been removed by a full app run. This is not merely a
|
||||
# performance optimization, but also fixes following potential situation:
|
||||
# A fragment run might create a new ScriptRunner when the current
|
||||
# ScriptRunner is in state STOPPED (in this case, the 'success' variable
|
||||
# below is false and the new ScriptRunner is created). This will lead to all
|
||||
# events that were not sent / received from the previous script runner to be
|
||||
# ignored in _handle_scriptrunner_event_on_event_loop, because the
|
||||
# _script_runner changed. When the full app rerun ScriptRunner is done
|
||||
# (STOPPED) but its events are not processed before the new ScriptRunner is
|
||||
# created, its finished message is not sent to the frontend and no
|
||||
# full-app-run cleanup is happening. This scenario can be triggered by the
|
||||
# example app described in
|
||||
# https://github.com/streamlit/streamlit/issues/9921, where the dialog
|
||||
# sometimes stays open.
|
||||
if fragment_id and not self._fragment_storage.contains(fragment_id):
|
||||
_LOGGER.info(
|
||||
f"The fragment with id {fragment_id} does not exist anymore - "
|
||||
"it might have been removed during a preceding full-app rerun."
|
||||
)
|
||||
return
|
||||
|
||||
if client_state.HasField("context_info"):
|
||||
self._client_state.context_info.CopyFrom(client_state.context_info)
|
||||
|
||||
rerun_data = RerunData(
|
||||
client_state.query_string,
|
||||
client_state.widget_states,
|
||||
client_state.page_script_hash,
|
||||
client_state.page_name,
|
||||
fragment_id=fragment_id if fragment_id else None,
|
||||
is_auto_rerun=client_state.is_auto_rerun,
|
||||
context_info=client_state.context_info,
|
||||
)
|
||||
else:
|
||||
rerun_data = RerunData()
|
||||
|
||||
if self._scriptrunner is not None:
|
||||
if (
|
||||
bool(config.get_option("runner.fastReruns"))
|
||||
and not rerun_data.fragment_id
|
||||
):
|
||||
# If fastReruns is enabled and this is *not* a rerun of a fragment,
|
||||
# we don't send rerun requests to our existing ScriptRunner. Instead, we
|
||||
# tell it to shut down. We'll then spin up a new ScriptRunner, below, to
|
||||
# handle the rerun immediately.
|
||||
self._scriptrunner.request_stop()
|
||||
self._scriptrunner = None
|
||||
else:
|
||||
# Either fastReruns is not enabled or this RERUN request is a request to
|
||||
# run a fragment. We send our current ScriptRunner a rerun request, and
|
||||
# if it's accepted, we're done.
|
||||
success = self._scriptrunner.request_rerun(rerun_data)
|
||||
if success:
|
||||
return
|
||||
|
||||
# If we are here, then either we have no ScriptRunner, or our
|
||||
# current ScriptRunner is shutting down and cannot handle a rerun
|
||||
# request - so we'll create and start a new ScriptRunner.
|
||||
self._create_scriptrunner(rerun_data)
|
||||
|
||||
def request_script_stop(self) -> None:
|
||||
"""Request that the scriptrunner stop execution.
|
||||
|
||||
Does nothing if no scriptrunner exists.
|
||||
"""
|
||||
if self._scriptrunner is not None:
|
||||
self._scriptrunner.request_stop()
|
||||
|
||||
def clear_user_info(self) -> None:
|
||||
"""Clear the user info for this session."""
|
||||
self._user_info.clear()
|
||||
|
||||
def _create_scriptrunner(self, initial_rerun_data: RerunData) -> None:
|
||||
"""Create and run a new ScriptRunner with the given RerunData."""
|
||||
self._scriptrunner = ScriptRunner(
|
||||
session_id=self.id,
|
||||
main_script_path=self._script_data.main_script_path,
|
||||
session_state=self._session_state,
|
||||
uploaded_file_mgr=self._uploaded_file_mgr,
|
||||
script_cache=self._script_cache,
|
||||
initial_rerun_data=initial_rerun_data,
|
||||
user_info=self._user_info,
|
||||
fragment_storage=self._fragment_storage,
|
||||
pages_manager=self._pages_manager,
|
||||
)
|
||||
self._scriptrunner.on_event.connect(self._on_scriptrunner_event)
|
||||
self._scriptrunner.start()
|
||||
|
||||
@property
|
||||
def session_state(self) -> SessionState:
|
||||
return self._session_state
|
||||
|
||||
def _should_rerun_on_file_change(self, filepath: str) -> bool:
|
||||
pages = self._pages_manager.get_pages()
|
||||
|
||||
changed_page_script_hash = next(
|
||||
filter(lambda k: pages[k]["script_path"] == filepath, pages),
|
||||
None,
|
||||
)
|
||||
|
||||
if changed_page_script_hash is not None:
|
||||
current_page_script_hash = self._client_state.page_script_hash
|
||||
return changed_page_script_hash == current_page_script_hash
|
||||
|
||||
return True
|
||||
|
||||
def _on_source_file_changed(self, filepath: str | None = None) -> None:
|
||||
"""One of our source files changed. Clear the cache and schedule a rerun if
|
||||
appropriate.
|
||||
"""
|
||||
self._script_cache.clear()
|
||||
|
||||
if filepath is not None and not self._should_rerun_on_file_change(filepath):
|
||||
return
|
||||
|
||||
if self._run_on_save:
|
||||
self.request_rerun(self._client_state)
|
||||
else:
|
||||
self._enqueue_forward_msg(self._create_file_change_message())
|
||||
|
||||
def _on_secrets_file_changed(self, _) -> None:
|
||||
"""Called when `secrets.file_change_listener` emits a Signal."""
|
||||
|
||||
# NOTE: At the time of writing, this function only calls
|
||||
# `_on_source_file_changed`. The reason behind creating this function instead of
|
||||
# just passing `_on_source_file_changed` to `connect` / `disconnect` directly is
|
||||
# that every function that is passed to `connect` / `disconnect` must have at
|
||||
# least one argument for `sender` (in this case we don't really care about it,
|
||||
# thus `_`), and introducing an unnecessary argument to
|
||||
# `_on_source_file_changed` just for this purpose sounded finicky.
|
||||
self._on_source_file_changed()
|
||||
|
||||
def _clear_queue(self, fragment_ids_this_run: list[str] | None = None) -> None:
|
||||
self._browser_queue.clear(
|
||||
retain_lifecycle_msgs=True, fragment_ids_this_run=fragment_ids_this_run
|
||||
)
|
||||
|
||||
def _on_scriptrunner_event(
|
||||
self,
|
||||
sender: ScriptRunner | None,
|
||||
event: ScriptRunnerEvent,
|
||||
forward_msg: ForwardMsg | None = None,
|
||||
exception: BaseException | None = None,
|
||||
client_state: ClientState | None = None,
|
||||
page_script_hash: str | None = None,
|
||||
fragment_ids_this_run: list[str] | None = None,
|
||||
pages: dict[PageHash, PageInfo] | None = None,
|
||||
) -> None:
|
||||
"""Called when our ScriptRunner emits an event.
|
||||
|
||||
This is generally called from the sender ScriptRunner's script thread.
|
||||
We forward the event on to _handle_scriptrunner_event_on_event_loop,
|
||||
which will be called on the main thread.
|
||||
"""
|
||||
self._event_loop.call_soon_threadsafe(
|
||||
lambda: self._handle_scriptrunner_event_on_event_loop(
|
||||
sender,
|
||||
event,
|
||||
forward_msg,
|
||||
exception,
|
||||
client_state,
|
||||
page_script_hash,
|
||||
fragment_ids_this_run,
|
||||
pages,
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_scriptrunner_event_on_event_loop(
|
||||
self,
|
||||
sender: ScriptRunner | None,
|
||||
event: ScriptRunnerEvent,
|
||||
forward_msg: ForwardMsg | None = None,
|
||||
exception: BaseException | None = None,
|
||||
client_state: ClientState | None = None,
|
||||
page_script_hash: str | None = None,
|
||||
fragment_ids_this_run: list[str] | None = None,
|
||||
pages: dict[PageHash, PageInfo] | None = None,
|
||||
) -> None:
|
||||
"""Handle a ScriptRunner event.
|
||||
|
||||
This function must only be called on our eventloop thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sender : ScriptRunner | None
|
||||
The ScriptRunner that emitted the event. (This may be set to
|
||||
None when called from `handle_backmsg_exception`, if no
|
||||
ScriptRunner was active when the backmsg exception was raised.)
|
||||
|
||||
event : ScriptRunnerEvent
|
||||
The event type.
|
||||
|
||||
forward_msg : ForwardMsg | None
|
||||
The ForwardMsg to send to the frontend. Set only for the
|
||||
ENQUEUE_FORWARD_MSG event.
|
||||
|
||||
exception : BaseException | None
|
||||
An exception thrown during compilation. Set only for the
|
||||
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
|
||||
|
||||
client_state : streamlit.proto.ClientState_pb2.ClientState | None
|
||||
The ScriptRunner's final ClientState. Set only for the
|
||||
SHUTDOWN event.
|
||||
|
||||
page_script_hash : str | None
|
||||
A hash of the script path corresponding to the page currently being
|
||||
run. Set only for the SCRIPT_STARTED event.
|
||||
|
||||
fragment_ids_this_run : list[str] | None
|
||||
The fragment IDs of the fragments being executed in this script run. Only
|
||||
set for the SCRIPT_STARTED event. If this value is falsy, this script run
|
||||
must be for the full script.
|
||||
|
||||
clear_forward_msg_queue : bool
|
||||
If set (the default), clears the queue of forward messages to be sent to the
|
||||
browser. Set only for the SCRIPT_STARTED event.
|
||||
"""
|
||||
|
||||
assert self._event_loop == asyncio.get_running_loop(), (
|
||||
"This function must only be called on the eventloop thread the AppSession was created on."
|
||||
)
|
||||
|
||||
if sender is not self._scriptrunner:
|
||||
# This event was sent by a non-current ScriptRunner; ignore it.
|
||||
# This can happen after sppinng up a new ScriptRunner (to handle a
|
||||
# rerun request, for example) while another ScriptRunner is still
|
||||
# shutting down. The shutting-down ScriptRunner may still
|
||||
# emit events.
|
||||
_LOGGER.debug("Ignoring event from non-current ScriptRunner: %s", event)
|
||||
return
|
||||
|
||||
prev_state = self._state
|
||||
|
||||
if event == ScriptRunnerEvent.SCRIPT_STARTED:
|
||||
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
|
||||
self._state = AppSessionState.APP_IS_RUNNING
|
||||
assert page_script_hash is not None, (
|
||||
"page_script_hash must be set for the SCRIPT_STARTED event"
|
||||
)
|
||||
|
||||
# Update the client state with the new page_script_hash if
|
||||
# necessary. This handles an edge case where a script is never
|
||||
# finishes (eg. by calling st.rerun()), but the page has changed
|
||||
# via st.navigation()
|
||||
if page_script_hash != self._client_state.page_script_hash:
|
||||
self._client_state.page_script_hash = page_script_hash
|
||||
|
||||
self._clear_queue(fragment_ids_this_run)
|
||||
|
||||
self._enqueue_forward_msg(
|
||||
self._create_new_session_message(
|
||||
page_script_hash, fragment_ids_this_run, pages
|
||||
)
|
||||
)
|
||||
|
||||
elif (
|
||||
event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
or event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR
|
||||
or event == ScriptRunnerEvent.FRAGMENT_STOPPED_WITH_SUCCESS
|
||||
):
|
||||
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
|
||||
self._state = AppSessionState.APP_NOT_RUNNING
|
||||
|
||||
if event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS:
|
||||
status = ForwardMsg.FINISHED_SUCCESSFULLY
|
||||
elif event == ScriptRunnerEvent.FRAGMENT_STOPPED_WITH_SUCCESS:
|
||||
status = ForwardMsg.FINISHED_FRAGMENT_RUN_SUCCESSFULLY
|
||||
else:
|
||||
status = ForwardMsg.FINISHED_WITH_COMPILE_ERROR
|
||||
|
||||
self._enqueue_forward_msg(self._create_script_finished_message(status))
|
||||
self._debug_last_backmsg_id = None
|
||||
|
||||
if (
|
||||
event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
or event == ScriptRunnerEvent.FRAGMENT_STOPPED_WITH_SUCCESS
|
||||
):
|
||||
# The script completed successfully: update our
|
||||
# LocalSourcesWatcher to account for any source code changes
|
||||
# that change which modules should be watched.
|
||||
if self._local_sources_watcher:
|
||||
self._local_sources_watcher.update_watched_modules()
|
||||
self._local_sources_watcher.update_watched_pages()
|
||||
else:
|
||||
# The script didn't complete successfully: send the exception
|
||||
# to the frontend.
|
||||
assert exception is not None, (
|
||||
"exception must be set for the SCRIPT_STOPPED_WITH_COMPILE_ERROR event"
|
||||
)
|
||||
msg = ForwardMsg()
|
||||
exception_utils.marshall(
|
||||
msg.session_event.script_compilation_exception, exception
|
||||
)
|
||||
self._enqueue_forward_msg(msg)
|
||||
|
||||
elif event == ScriptRunnerEvent.SCRIPT_STOPPED_FOR_RERUN:
|
||||
self._state = AppSessionState.APP_NOT_RUNNING
|
||||
self._enqueue_forward_msg(
|
||||
self._create_script_finished_message(
|
||||
ForwardMsg.FINISHED_EARLY_FOR_RERUN
|
||||
)
|
||||
)
|
||||
if self._local_sources_watcher:
|
||||
self._local_sources_watcher.update_watched_modules()
|
||||
|
||||
elif event == ScriptRunnerEvent.SHUTDOWN:
|
||||
assert client_state is not None, (
|
||||
"client_state must be set for the SHUTDOWN event"
|
||||
)
|
||||
|
||||
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
|
||||
# Only clear media files if the script is done running AND the
|
||||
# session is actually shutting down.
|
||||
runtime.get_instance().media_file_mgr.clear_session_refs(self.id)
|
||||
|
||||
self._client_state = client_state
|
||||
self._scriptrunner = None
|
||||
|
||||
elif event == ScriptRunnerEvent.ENQUEUE_FORWARD_MSG:
|
||||
assert forward_msg is not None, (
|
||||
"null forward_msg in ENQUEUE_FORWARD_MSG event"
|
||||
)
|
||||
self._enqueue_forward_msg(forward_msg)
|
||||
|
||||
# Send a message if our run state changed
|
||||
app_was_running = prev_state == AppSessionState.APP_IS_RUNNING
|
||||
app_is_running = self._state == AppSessionState.APP_IS_RUNNING
|
||||
if app_is_running != app_was_running:
|
||||
self._enqueue_forward_msg(self._create_session_status_changed_message())
|
||||
|
||||
def _create_session_status_changed_message(self) -> ForwardMsg:
|
||||
"""Create and return a session_status_changed ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
msg.session_status_changed.run_on_save = self._run_on_save
|
||||
msg.session_status_changed.script_is_running = (
|
||||
self._state == AppSessionState.APP_IS_RUNNING
|
||||
)
|
||||
return msg
|
||||
|
||||
def _create_file_change_message(self) -> ForwardMsg:
|
||||
"""Create and return a 'script_changed_on_disk' ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
msg.session_event.script_changed_on_disk = True
|
||||
return msg
|
||||
|
||||
def _create_new_session_message(
|
||||
self,
|
||||
page_script_hash: str,
|
||||
fragment_ids_this_run: list[str] | None = None,
|
||||
pages: dict[PageHash, PageInfo] | None = None,
|
||||
) -> ForwardMsg:
|
||||
"""Create and return a new_session ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
|
||||
msg.new_session.script_run_id = _generate_scriptrun_id()
|
||||
msg.new_session.name = self._script_data.name
|
||||
msg.new_session.main_script_path = self._pages_manager.main_script_path
|
||||
msg.new_session.main_script_hash = self._pages_manager.main_script_hash
|
||||
msg.new_session.page_script_hash = page_script_hash
|
||||
|
||||
if fragment_ids_this_run:
|
||||
msg.new_session.fragment_ids_this_run.extend(fragment_ids_this_run)
|
||||
|
||||
self._populate_app_pages(
|
||||
msg.new_session, pages or self._pages_manager.get_pages()
|
||||
)
|
||||
_populate_config_msg(msg.new_session.config)
|
||||
_populate_theme_msg(msg.new_session.custom_theme)
|
||||
_populate_theme_msg(
|
||||
msg.new_session.custom_theme.sidebar,
|
||||
f"theme.{config.CustomThemeCategories.SIDEBAR.value}",
|
||||
)
|
||||
|
||||
# Immutable session data. We send this every time a new session is
|
||||
# started, to avoid having to track whether the client has already
|
||||
# received it. It does not change from run to run; it's up to the
|
||||
# to perform one-time initialization only once.
|
||||
imsg = msg.new_session.initialize
|
||||
|
||||
_populate_user_info_msg(imsg.user_info)
|
||||
|
||||
imsg.environment_info.streamlit_version = STREAMLIT_VERSION_STRING
|
||||
imsg.environment_info.python_version = ".".join(map(str, sys.version_info))
|
||||
|
||||
imsg.session_status.run_on_save = self._run_on_save
|
||||
imsg.session_status.script_is_running = (
|
||||
self._state == AppSessionState.APP_IS_RUNNING
|
||||
)
|
||||
|
||||
imsg.is_hello = self._script_data.is_hello
|
||||
imsg.session_id = self.id
|
||||
|
||||
return msg
|
||||
|
||||
def _create_script_finished_message(
|
||||
self, status: ForwardMsg.ScriptFinishedStatus.ValueType
|
||||
) -> ForwardMsg:
|
||||
"""Create and return a script_finished ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
msg.script_finished = status
|
||||
return msg
|
||||
|
||||
def _create_exception_message(self, e: BaseException) -> ForwardMsg:
|
||||
"""Create and return an Exception ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
exception_utils.marshall(msg.delta.new_element.exception, e)
|
||||
return msg
|
||||
|
||||
def _handle_git_information_request(self) -> None:
|
||||
msg = ForwardMsg()
|
||||
|
||||
try:
|
||||
from streamlit.git_util import GitRepo
|
||||
|
||||
repo = GitRepo(self._script_data.main_script_path)
|
||||
|
||||
repo_info = repo.get_repo_info()
|
||||
if repo_info is None:
|
||||
return
|
||||
|
||||
repository_name, branch, module = repo_info
|
||||
|
||||
repository_name = repository_name.removesuffix(".git")
|
||||
|
||||
msg.git_info_changed.repository = repository_name
|
||||
msg.git_info_changed.branch = branch
|
||||
msg.git_info_changed.module = module
|
||||
|
||||
msg.git_info_changed.untracked_files[:] = repo.untracked_files
|
||||
msg.git_info_changed.uncommitted_files[:] = repo.uncommitted_files
|
||||
|
||||
if repo.is_head_detached:
|
||||
msg.git_info_changed.state = GitInfo.GitStates.HEAD_DETACHED
|
||||
elif len(repo.ahead_commits) > 0:
|
||||
msg.git_info_changed.state = GitInfo.GitStates.AHEAD_OF_REMOTE
|
||||
else:
|
||||
msg.git_info_changed.state = GitInfo.GitStates.DEFAULT
|
||||
|
||||
self._enqueue_forward_msg(msg)
|
||||
except Exception as ex:
|
||||
# Users may never even install Git in the first place, so this
|
||||
# error requires no action. It can be useful for debugging.
|
||||
_LOGGER.debug("Obtaining Git information produced an error", exc_info=ex)
|
||||
|
||||
def _handle_rerun_script_request(
|
||||
self, client_state: ClientState | None = None
|
||||
) -> None:
|
||||
"""Tell the ScriptRunner to re-run its script.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client_state : streamlit.proto.ClientState_pb2.ClientState | None
|
||||
The ClientState protobuf to run the script with, or None
|
||||
to use previous client state.
|
||||
|
||||
"""
|
||||
self.request_rerun(client_state)
|
||||
|
||||
def _handle_stop_script_request(self) -> None:
|
||||
"""Tell the ScriptRunner to stop running its script."""
|
||||
self.request_script_stop()
|
||||
|
||||
def _handle_clear_cache_request(self) -> None:
|
||||
"""Clear this app's cache.
|
||||
|
||||
Because this cache is global, it will be cleared for all users.
|
||||
|
||||
"""
|
||||
caching.cache_data.clear()
|
||||
caching.cache_resource.clear()
|
||||
self._session_state.clear()
|
||||
|
||||
def _handle_app_heartbeat_request(self) -> None:
|
||||
"""Handle an incoming app heartbeat.
|
||||
|
||||
The heartbeat indicates the frontend is active and keeps the
|
||||
websocket from going idle and disconnecting.
|
||||
|
||||
The actual handler here is a noop
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
def _handle_set_run_on_save_request(self, new_value: bool) -> None:
|
||||
"""Change our run_on_save flag to the given value.
|
||||
|
||||
The browser will be notified of the change.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_value : bool
|
||||
New run_on_save value
|
||||
|
||||
"""
|
||||
self._run_on_save = new_value
|
||||
self._enqueue_forward_msg(self._create_session_status_changed_message())
|
||||
|
||||
def _handle_file_urls_request(self, file_urls_request: FileURLsRequest) -> None:
|
||||
"""Handle a file_urls_request BackMsg sent by the client."""
|
||||
msg = ForwardMsg()
|
||||
msg.file_urls_response.response_id = file_urls_request.request_id
|
||||
|
||||
upload_url_infos = self._uploaded_file_mgr.get_upload_urls(
|
||||
self.id, file_urls_request.file_names
|
||||
)
|
||||
|
||||
for upload_url_info in upload_url_infos:
|
||||
msg.file_urls_response.file_urls.append(
|
||||
FileURLs(
|
||||
file_id=upload_url_info.file_id,
|
||||
upload_url=upload_url_info.upload_url,
|
||||
delete_url=upload_url_info.delete_url,
|
||||
)
|
||||
)
|
||||
|
||||
self._enqueue_forward_msg(msg)
|
||||
|
||||
def _populate_app_pages(
|
||||
self, msg: NewSession, pages: dict[PageHash, PageInfo]
|
||||
) -> None:
|
||||
for page_script_hash, page_info in pages.items():
|
||||
page_proto = msg.app_pages.add()
|
||||
|
||||
page_proto.page_script_hash = page_script_hash
|
||||
page_proto.page_name = page_info["page_name"].replace("_", " ")
|
||||
page_proto.url_pathname = page_info["page_name"]
|
||||
page_proto.icon = page_info["icon"]
|
||||
|
||||
|
||||
# Config.ToolbarMode.ValueType does not exist at runtime (only in the pyi stubs), so
|
||||
# we need to use quotes.
|
||||
# This field will be available at runtime as of protobuf 3.20.1, but
|
||||
# we are using an older version.
|
||||
# For details, see: https://github.com/protocolbuffers/protobuf/issues/8175
|
||||
def _get_toolbar_mode() -> Config.ToolbarMode.ValueType:
|
||||
config_key = "client.toolbarMode"
|
||||
config_value = config.get_option(config_key)
|
||||
enum_value: Config.ToolbarMode.ValueType | None = getattr(
|
||||
Config.ToolbarMode, config_value.upper()
|
||||
)
|
||||
if enum_value is None:
|
||||
allowed_values = ", ".join(k.lower() for k in Config.ToolbarMode.keys())
|
||||
raise ValueError(
|
||||
f"Config {config_key!r} expects to have one of "
|
||||
f"the following values: {allowed_values}. "
|
||||
f"Current value: {config_value}"
|
||||
)
|
||||
return enum_value
|
||||
|
||||
|
||||
def _populate_config_msg(msg: Config) -> None:
|
||||
msg.gather_usage_stats = config.get_option("browser.gatherUsageStats")
|
||||
msg.max_cached_message_age = config.get_option("global.maxCachedMessageAge")
|
||||
msg.allow_run_on_save = config.get_option("server.allowRunOnSave")
|
||||
msg.hide_top_bar = config.get_option("ui.hideTopBar")
|
||||
if config.get_option("client.showSidebarNavigation") is False:
|
||||
msg.hide_sidebar_nav = True
|
||||
msg.toolbar_mode = _get_toolbar_mode()
|
||||
|
||||
|
||||
def _populate_theme_msg(msg: CustomThemeConfig, section: str = "theme") -> None:
|
||||
theme_opts = config.get_options_for_section(section)
|
||||
if all(val is None for val in theme_opts.values()):
|
||||
return
|
||||
|
||||
for option_name, option_val in theme_opts.items():
|
||||
# We need to ignore some config options here that need special handling
|
||||
# and cannot directly be set on the protobuf.
|
||||
if option_name not in {"base", "font", "fontFaces"} and option_val is not None:
|
||||
setattr(msg, to_snake_case(option_name), option_val)
|
||||
|
||||
# NOTE: If unset, base and font will default to the protobuf enum zero
|
||||
# values, which are BaseTheme.LIGHT and FontFamily.SANS_SERIF,
|
||||
# respectively. This is why we both don't handle the cases explicitly and
|
||||
# also only log a warning when receiving invalid base/font options.
|
||||
base_map = {
|
||||
"light": msg.BaseTheme.LIGHT,
|
||||
"dark": msg.BaseTheme.DARK,
|
||||
}
|
||||
base = theme_opts.get("base", None)
|
||||
if base is not None:
|
||||
if base not in base_map:
|
||||
_LOGGER.warning(
|
||||
f'"{base}" is an invalid value for theme.base.'
|
||||
f" Allowed values include {list(base_map.keys())}."
|
||||
' Setting theme.base to "light".'
|
||||
)
|
||||
else:
|
||||
msg.base = base_map[base]
|
||||
|
||||
# Since the font field uses the deprecated enum, we need to put the font
|
||||
# config into the body_font field instead:
|
||||
body_font = theme_opts.get("font", None)
|
||||
if body_font:
|
||||
msg.body_font = body_font
|
||||
|
||||
font_faces = theme_opts.get("fontFaces", None)
|
||||
# If fontFaces was configured via config.toml, it's already a parsed list of
|
||||
# dictionaries. However, if it was provided via env variable or via CLI arg,
|
||||
# it's a json string that still needs to be parsed.
|
||||
if isinstance(font_faces, str):
|
||||
try:
|
||||
font_faces = json.loads(font_faces)
|
||||
except Exception as e:
|
||||
_LOGGER.warning(
|
||||
"Failed to parse the theme.fontFaces config option with json.loads: "
|
||||
f"{font_faces}.",
|
||||
exc_info=e,
|
||||
)
|
||||
font_faces = None
|
||||
|
||||
if font_faces is not None:
|
||||
for font_face in font_faces:
|
||||
try:
|
||||
msg.font_faces.append(ParseDict(font_face, FontFace()))
|
||||
except Exception as e:
|
||||
_LOGGER.warning(
|
||||
f"Failed to parse the theme.fontFaces config option: {font_face}.",
|
||||
exc_info=e,
|
||||
)
|
||||
|
||||
|
||||
def _populate_user_info_msg(msg: UserInfo) -> None:
|
||||
msg.installation_id = Installation.instance().installation_id
|
||||
msg.installation_id_v3 = Installation.instance().installation_id_v3
|
||||
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from streamlit.runtime.caching.cache_data_api import (
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX,
|
||||
CacheDataAPI,
|
||||
get_data_cache_stats_provider,
|
||||
)
|
||||
from streamlit.runtime.caching.cache_errors import CACHE_DOCS_URL
|
||||
from streamlit.runtime.caching.cache_resource_api import (
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX,
|
||||
CacheResourceAPI,
|
||||
get_resource_cache_stats_provider,
|
||||
)
|
||||
from streamlit.runtime.caching.legacy_cache_api import cache as _cache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.protobuf.message import Message
|
||||
|
||||
from streamlit.proto.Block_pb2 import Block
|
||||
|
||||
|
||||
def save_element_message(
|
||||
delta_type: str,
|
||||
element_proto: Message,
|
||||
invoked_dg_id: str,
|
||||
used_dg_id: str,
|
||||
returned_dg_id: str,
|
||||
) -> None:
|
||||
"""Save the message for an element to a thread-local callstack, so it can
|
||||
be used later to replay the element when a cache-decorated function's
|
||||
execution is skipped.
|
||||
"""
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX.save_element_message(
|
||||
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
|
||||
)
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_element_message(
|
||||
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
|
||||
)
|
||||
|
||||
|
||||
def save_block_message(
|
||||
block_proto: Block,
|
||||
invoked_dg_id: str,
|
||||
used_dg_id: str,
|
||||
returned_dg_id: str,
|
||||
) -> None:
|
||||
"""Save the message for a block to a thread-local callstack, so it can
|
||||
be used later to replay the block when a cache-decorated function's
|
||||
execution is skipped.
|
||||
"""
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX.save_block_message(
|
||||
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
|
||||
)
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_block_message(
|
||||
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
|
||||
)
|
||||
|
||||
|
||||
def save_media_data(image_data: bytes | str, mimetype: str, image_id: str) -> None:
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
|
||||
|
||||
|
||||
# Create and export public API singletons.
|
||||
cache_data = CacheDataAPI(decorator_metric_name="cache_data")
|
||||
cache_resource = CacheResourceAPI(decorator_metric_name="cache_resource")
|
||||
# TODO(lukasmasuch): This is the legacy cache API name which is deprecated
|
||||
# and it should be removed in the future.
|
||||
cache = _cache
|
||||
|
||||
|
||||
__all__ = [
|
||||
"cache",
|
||||
"CACHE_DOCS_URL",
|
||||
"save_element_message",
|
||||
"save_block_message",
|
||||
"save_media_data",
|
||||
"get_data_cache_stats_provider",
|
||||
"get_resource_cache_stats_provider",
|
||||
"cache_data",
|
||||
"cache_resource",
|
||||
]
|
||||
@@ -0,0 +1,665 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""@st.cache_data: pickle-based caching."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import threading
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Final,
|
||||
Literal,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import streamlit as st
|
||||
from streamlit import runtime
|
||||
from streamlit.errors import StreamlitAPIException
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching.cache_errors import CacheError, CacheKeyNotFoundError
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
from streamlit.runtime.caching.cache_utils import (
|
||||
Cache,
|
||||
CachedFuncInfo,
|
||||
make_cached_func_wrapper,
|
||||
)
|
||||
from streamlit.runtime.caching.cached_message_replay import (
|
||||
CachedMessageReplayContext,
|
||||
CachedResult,
|
||||
MsgData,
|
||||
show_widget_replay_deprecation,
|
||||
)
|
||||
from streamlit.runtime.caching.storage import (
|
||||
CacheStorage,
|
||||
CacheStorageContext,
|
||||
CacheStorageError,
|
||||
CacheStorageKeyNotFoundError,
|
||||
CacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
InvalidCacheStorageContext,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.dummy_cache_storage import (
|
||||
MemoryCacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
|
||||
from streamlit.time_util import time_to_seconds
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import types
|
||||
from datetime import timedelta
|
||||
|
||||
from streamlit.runtime.caching.hashing import HashFuncsDict
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.DATA)
|
||||
|
||||
# The cache persistence options we support: "disk" or None
|
||||
CachePersistType: TypeAlias = Union[Literal["disk"], None]
|
||||
|
||||
|
||||
class CachedDataFuncInfo(CachedFuncInfo):
|
||||
"""Implements the CachedFuncInfo interface for @st.cache_data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: types.FunctionType,
|
||||
show_spinner: bool | str,
|
||||
persist: CachePersistType,
|
||||
max_entries: int | None,
|
||||
ttl: float | timedelta | str | None,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
func,
|
||||
show_spinner=show_spinner,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
self.persist = persist
|
||||
self.max_entries = max_entries
|
||||
self.ttl = ttl
|
||||
|
||||
self.validate_params()
|
||||
|
||||
@property
|
||||
def cache_type(self) -> CacheType:
|
||||
return CacheType.DATA
|
||||
|
||||
@property
|
||||
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
|
||||
return CACHE_DATA_MESSAGE_REPLAY_CTX
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""A human-readable name for the cached function."""
|
||||
return f"{self.func.__module__}.{self.func.__qualname__}"
|
||||
|
||||
def get_function_cache(self, function_key: str) -> Cache:
|
||||
return _data_caches.get_cache(
|
||||
key=function_key,
|
||||
persist=self.persist,
|
||||
max_entries=self.max_entries,
|
||||
ttl=self.ttl,
|
||||
display_name=self.display_name,
|
||||
)
|
||||
|
||||
def validate_params(self) -> None:
|
||||
"""
|
||||
Validate the params passed to @st.cache_data are compatible with cache storage.
|
||||
|
||||
When called, this method could log warnings if cache params are invalid
|
||||
for current storage.
|
||||
"""
|
||||
_data_caches.validate_cache_params(
|
||||
function_name=self.func.__name__,
|
||||
persist=self.persist,
|
||||
max_entries=self.max_entries,
|
||||
ttl=self.ttl,
|
||||
)
|
||||
|
||||
|
||||
class DataCaches(CacheStatsProvider):
|
||||
"""Manages all DataCache instances."""
|
||||
|
||||
def __init__(self):
|
||||
self._caches_lock = threading.Lock()
|
||||
self._function_caches: dict[str, DataCache] = {}
|
||||
|
||||
def get_cache(
|
||||
self,
|
||||
key: str,
|
||||
persist: CachePersistType,
|
||||
max_entries: int | None,
|
||||
ttl: int | float | timedelta | str | None,
|
||||
display_name: str,
|
||||
) -> DataCache:
|
||||
"""Return the mem cache for the given key.
|
||||
|
||||
If it doesn't exist, create a new one with the given params.
|
||||
"""
|
||||
|
||||
ttl_seconds = time_to_seconds(ttl, coerce_none_to_inf=False)
|
||||
|
||||
# Get the existing cache, if it exists, and validate that its params
|
||||
# haven't changed.
|
||||
with self._caches_lock:
|
||||
cache = self._function_caches.get(key)
|
||||
if (
|
||||
cache is not None
|
||||
and cache.ttl_seconds == ttl_seconds
|
||||
and cache.max_entries == max_entries
|
||||
and cache.persist == persist
|
||||
):
|
||||
return cache
|
||||
|
||||
# Close the existing cache's storage, if it exists.
|
||||
if cache is not None:
|
||||
_LOGGER.debug(
|
||||
"Closing existing DataCache storage "
|
||||
"(key=%s, persist=%s, max_entries=%s, ttl=%s) "
|
||||
"before creating new one with different params",
|
||||
key,
|
||||
persist,
|
||||
max_entries,
|
||||
ttl,
|
||||
)
|
||||
cache.storage.close()
|
||||
|
||||
# Create a new cache object and put it in our dict
|
||||
_LOGGER.debug(
|
||||
"Creating new DataCache (key=%s, persist=%s, max_entries=%s, ttl=%s)",
|
||||
key,
|
||||
persist,
|
||||
max_entries,
|
||||
ttl,
|
||||
)
|
||||
|
||||
cache_context = self.create_cache_storage_context(
|
||||
function_key=key,
|
||||
function_name=display_name,
|
||||
ttl_seconds=ttl_seconds,
|
||||
max_entries=max_entries,
|
||||
persist=persist,
|
||||
)
|
||||
cache_storage_manager = self.get_storage_manager()
|
||||
storage = cache_storage_manager.create(cache_context)
|
||||
|
||||
cache = DataCache(
|
||||
key=key,
|
||||
storage=storage,
|
||||
persist=persist,
|
||||
max_entries=max_entries,
|
||||
ttl_seconds=ttl_seconds,
|
||||
display_name=display_name,
|
||||
)
|
||||
self._function_caches[key] = cache
|
||||
return cache
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""Clear all in-memory and on-disk caches."""
|
||||
with self._caches_lock:
|
||||
try:
|
||||
# try to remove in optimal way if such ability provided by
|
||||
# storage manager clear_all method;
|
||||
# if not implemented, fallback to remove all
|
||||
# available storages one by one
|
||||
self.get_storage_manager().clear_all()
|
||||
except NotImplementedError:
|
||||
for data_cache in self._function_caches.values():
|
||||
data_cache.clear()
|
||||
data_cache.storage.close()
|
||||
self._function_caches = {}
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
with self._caches_lock:
|
||||
# Shallow-clone our caches. We don't want to hold the global
|
||||
# lock during stats-gathering.
|
||||
function_caches = self._function_caches.copy()
|
||||
|
||||
stats: list[CacheStat] = []
|
||||
for cache in function_caches.values():
|
||||
stats.extend(cache.get_stats())
|
||||
return group_stats(stats)
|
||||
|
||||
def validate_cache_params(
|
||||
self,
|
||||
function_name: str,
|
||||
persist: CachePersistType,
|
||||
max_entries: int | None,
|
||||
ttl: int | float | timedelta | str | None,
|
||||
) -> None:
|
||||
"""Validate that the cache params are valid for given storage.
|
||||
|
||||
Raises
|
||||
------
|
||||
InvalidCacheStorageContext
|
||||
Raised if the cache storage manager is not able to work with provided
|
||||
CacheStorageContext.
|
||||
"""
|
||||
|
||||
ttl_seconds = time_to_seconds(ttl, coerce_none_to_inf=False)
|
||||
|
||||
cache_context = self.create_cache_storage_context(
|
||||
function_key="DUMMY_KEY",
|
||||
function_name=function_name,
|
||||
ttl_seconds=ttl_seconds,
|
||||
max_entries=max_entries,
|
||||
persist=persist,
|
||||
)
|
||||
try:
|
||||
self.get_storage_manager().check_context(cache_context)
|
||||
except InvalidCacheStorageContext as e:
|
||||
_LOGGER.error(
|
||||
"Cache params for function %s are incompatible with current "
|
||||
"cache storage manager.",
|
||||
function_name,
|
||||
exc_info=e,
|
||||
)
|
||||
raise
|
||||
|
||||
def create_cache_storage_context(
|
||||
self,
|
||||
function_key: str,
|
||||
function_name: str,
|
||||
persist: CachePersistType,
|
||||
ttl_seconds: float | None,
|
||||
max_entries: int | None,
|
||||
) -> CacheStorageContext:
|
||||
return CacheStorageContext(
|
||||
function_key=function_key,
|
||||
function_display_name=function_name,
|
||||
ttl_seconds=ttl_seconds,
|
||||
max_entries=max_entries,
|
||||
persist=persist,
|
||||
)
|
||||
|
||||
def get_storage_manager(self) -> CacheStorageManager:
|
||||
if runtime.exists():
|
||||
return runtime.get_instance().cache_storage_manager
|
||||
else:
|
||||
# When running in "raw mode", we can't access the CacheStorageManager,
|
||||
# so we're falling back to InMemoryCache.
|
||||
_LOGGER.warning("No runtime found, using MemoryCacheStorageManager")
|
||||
return MemoryCacheStorageManager()
|
||||
|
||||
|
||||
# Singleton DataCaches instance
|
||||
_data_caches = DataCaches()
|
||||
|
||||
|
||||
def get_data_cache_stats_provider() -> CacheStatsProvider:
|
||||
"""Return the StatsProvider for all @st.cache_data functions."""
|
||||
return _data_caches
|
||||
|
||||
|
||||
class CacheDataAPI:
|
||||
"""Implements the public st.cache_data API: the @st.cache_data decorator, and
|
||||
st.cache_data.clear().
|
||||
"""
|
||||
|
||||
def __init__(self, decorator_metric_name: str):
|
||||
"""Create a CacheDataAPI instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decorator_metric_name
|
||||
The metric name to record for decorator usage.
|
||||
"""
|
||||
|
||||
# Parameterize the decorator metric name.
|
||||
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
|
||||
self._decorator = gather_metrics( # type: ignore
|
||||
decorator_metric_name, self._decorator
|
||||
)
|
||||
|
||||
# Type-annotate the decorator function.
|
||||
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
# Bare decorator usage
|
||||
@overload
|
||||
def __call__(self, func: F) -> F: ...
|
||||
|
||||
# Decorator with arguments
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
ttl: float | timedelta | str | None = None,
|
||||
max_entries: int | None = None,
|
||||
show_spinner: bool | str = True,
|
||||
persist: CachePersistType | bool = None,
|
||||
experimental_allow_widgets: bool = False,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
) -> Callable[[F], F]: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
func: F | None = None,
|
||||
*,
|
||||
ttl: float | timedelta | str | None = None,
|
||||
max_entries: int | None = None,
|
||||
show_spinner: bool | str = True,
|
||||
persist: CachePersistType | bool = None,
|
||||
experimental_allow_widgets: bool = False,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
):
|
||||
return self._decorator(
|
||||
func,
|
||||
ttl=ttl,
|
||||
max_entries=max_entries,
|
||||
persist=persist,
|
||||
show_spinner=show_spinner,
|
||||
experimental_allow_widgets=experimental_allow_widgets,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
|
||||
def _decorator(
|
||||
self,
|
||||
func: F | None = None,
|
||||
*,
|
||||
ttl: float | timedelta | str | None,
|
||||
max_entries: int | None,
|
||||
show_spinner: bool | str,
|
||||
persist: CachePersistType | bool,
|
||||
experimental_allow_widgets: bool,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
):
|
||||
"""Decorator to cache functions that return data (e.g. dataframe transforms, database queries, ML inference).
|
||||
|
||||
Cached objects are stored in "pickled" form, which means that the return
|
||||
value of a cached function must be pickleable. Each caller of the cached
|
||||
function gets its own copy of the cached data.
|
||||
|
||||
You can clear a function's cache with ``func.clear()`` or clear the entire
|
||||
cache with ``st.cache_data.clear()``.
|
||||
|
||||
A function's arguments must be hashable to cache it. If you have an
|
||||
unhashable argument (like a database connection) or an argument you
|
||||
want to exclude from caching, use an underscore prefix in the argument
|
||||
name. In this case, Streamlit will return a cached value when all other
|
||||
arguments match a previous function call. Alternatively, you can
|
||||
declare custom hashing functions with ``hash_funcs``.
|
||||
|
||||
To cache global resources, use ``st.cache_resource`` instead. Learn more
|
||||
about caching at https://docs.streamlit.io/develop/concepts/architecture/caching.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
The function to cache. Streamlit hashes the function's source code.
|
||||
|
||||
ttl : float, timedelta, str, or None
|
||||
The maximum time to keep an entry in the cache. Can be one of:
|
||||
|
||||
- ``None`` if cache entries should never expire (default).
|
||||
- A number specifying the time in seconds.
|
||||
- A string specifying the time in a format supported by `Pandas's
|
||||
Timedelta constructor <https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html>`_,
|
||||
e.g. ``"1d"``, ``"1.5 days"``, or ``"1h23s"``.
|
||||
- A ``timedelta`` object from `Python's built-in datetime library
|
||||
<https://docs.python.org/3/library/datetime.html#timedelta-objects>`_,
|
||||
e.g. ``timedelta(days=1)``.
|
||||
|
||||
Note that ``ttl`` will be ignored if ``persist="disk"`` or ``persist=True``.
|
||||
|
||||
max_entries : int or None
|
||||
The maximum number of entries to keep in the cache, or None
|
||||
for an unbounded cache. When a new entry is added to a full cache,
|
||||
the oldest cached entry will be removed. Defaults to None.
|
||||
|
||||
show_spinner : bool or str
|
||||
Enable the spinner. Default is True to show a spinner when there is
|
||||
a "cache miss" and the cached data is being created. If string,
|
||||
value of show_spinner param will be used for spinner text.
|
||||
|
||||
persist : "disk", bool, or None
|
||||
Optional location to persist cached data to. Passing "disk" (or True)
|
||||
will persist the cached data to the local disk. None (or False) will disable
|
||||
persistence. The default is None.
|
||||
|
||||
experimental_allow_widgets : bool
|
||||
Allow widgets to be used in the cached function. Defaults to False.
|
||||
|
||||
hash_funcs : dict or None
|
||||
Mapping of types or fully qualified names to hash functions.
|
||||
This is used to override the behavior of the hasher inside Streamlit's
|
||||
caching mechanism: when the hasher encounters an object, it will first
|
||||
check to see if its type matches a key in this dict and, if so, will use
|
||||
the provided function to generate a hash for it. See below for an example
|
||||
of how this can be used.
|
||||
|
||||
.. deprecated::
|
||||
The cached widget replay functionality was removed in 1.38. Please
|
||||
remove the ``experimental_allow_widgets`` parameter from your
|
||||
caching decorators. This parameter will be removed in a future
|
||||
version.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_data
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
>>>
|
||||
>>> d1 = fetch_and_clean_data(DATA_URL_1)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> d2 = fetch_and_clean_data(DATA_URL_1)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value. This means that now the data in d1 is the same as in d2.
|
||||
>>>
|
||||
>>> d3 = fetch_and_clean_data(DATA_URL_2)
|
||||
>>> # This is a different URL, so the function executes.
|
||||
|
||||
To set the ``persist`` parameter, use this command as follows:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_data(persist="disk")
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
|
||||
By default, all parameters to a cached function must be hashable.
|
||||
Any parameter whose name begins with ``_`` will not be hashed. You can use
|
||||
this as an "escape hatch" for parameters that are not hashable:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_data
|
||||
... def fetch_and_clean_data(_db_connection, num_rows):
|
||||
... # Fetch data from _db_connection here, and then clean it up.
|
||||
... return data
|
||||
>>>
|
||||
>>> connection = make_database_connection()
|
||||
>>> d1 = fetch_and_clean_data(connection, num_rows=10)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> another_connection = make_database_connection()
|
||||
>>> d2 = fetch_and_clean_data(another_connection, num_rows=10)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value - even though the _database_connection parameter was different
|
||||
>>> # in both calls.
|
||||
|
||||
A cached function's cache can be procedurally cleared:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_data
|
||||
... def fetch_and_clean_data(_db_connection, num_rows):
|
||||
... # Fetch data from _db_connection here, and then clean it up.
|
||||
... return data
|
||||
>>>
|
||||
>>> fetch_and_clean_data.clear(_db_connection, 50)
|
||||
>>> # Clear the cached entry for the arguments provided.
|
||||
>>>
|
||||
>>> fetch_and_clean_data.clear()
|
||||
>>> # Clear all cached entries for this function.
|
||||
|
||||
To override the default hashing behavior, pass a custom hash function.
|
||||
You can do that by mapping a type (e.g. ``datetime.datetime``) to a hash
|
||||
function (``lambda dt: dt.isoformat()``) like this:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> import datetime
|
||||
>>>
|
||||
>>> @st.cache_data(hash_funcs={datetime.datetime: lambda dt: dt.isoformat()})
|
||||
... def convert_to_utc(dt: datetime.datetime):
|
||||
... return dt.astimezone(datetime.timezone.utc)
|
||||
|
||||
Alternatively, you can map the type's fully-qualified name
|
||||
(e.g. ``"datetime.datetime"``) to the hash function instead:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> import datetime
|
||||
>>>
|
||||
>>> @st.cache_data(hash_funcs={"datetime.datetime": lambda dt: dt.isoformat()})
|
||||
... def convert_to_utc(dt: datetime.datetime):
|
||||
... return dt.astimezone(datetime.timezone.utc)
|
||||
|
||||
"""
|
||||
|
||||
# Parse our persist value into a string
|
||||
persist_string: CachePersistType
|
||||
if persist is True:
|
||||
persist_string = "disk"
|
||||
elif persist is False:
|
||||
persist_string = None
|
||||
else:
|
||||
persist_string = persist
|
||||
|
||||
if persist_string not in (None, "disk"):
|
||||
# We'll eventually have more persist options.
|
||||
raise StreamlitAPIException(
|
||||
f"Unsupported persist option '{persist}'. Valid values are 'disk' or None."
|
||||
)
|
||||
|
||||
if experimental_allow_widgets:
|
||||
show_widget_replay_deprecation("cache_data")
|
||||
|
||||
def wrapper(f):
|
||||
return make_cached_func_wrapper(
|
||||
CachedDataFuncInfo(
|
||||
func=f,
|
||||
persist=persist_string,
|
||||
show_spinner=show_spinner,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
)
|
||||
|
||||
if func is None:
|
||||
return wrapper
|
||||
|
||||
return make_cached_func_wrapper(
|
||||
CachedDataFuncInfo(
|
||||
func=cast("types.FunctionType", func),
|
||||
persist=persist_string,
|
||||
show_spinner=show_spinner,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
)
|
||||
|
||||
@gather_metrics("clear_data_caches")
|
||||
def clear(self) -> None:
|
||||
"""Clear all in-memory and on-disk data caches."""
|
||||
_data_caches.clear_all()
|
||||
|
||||
|
||||
class DataCache(Cache):
|
||||
"""Manages cached values for a single st.cache_data function."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
storage: CacheStorage,
|
||||
persist: CachePersistType,
|
||||
max_entries: int | None,
|
||||
ttl_seconds: float | None,
|
||||
display_name: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.display_name = display_name
|
||||
self.storage = storage
|
||||
self.ttl_seconds = ttl_seconds
|
||||
self.max_entries = max_entries
|
||||
self.persist = persist
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
if isinstance(self.storage, CacheStatsProvider):
|
||||
return self.storage.get_stats()
|
||||
return []
|
||||
|
||||
def read_result(self, key: str) -> CachedResult:
|
||||
"""Read a value and messages from the cache. Raise `CacheKeyNotFoundError`
|
||||
if the value doesn't exist, and `CacheError` if the value exists but can't
|
||||
be unpickled.
|
||||
"""
|
||||
try:
|
||||
pickled_entry = self.storage.get(key)
|
||||
except CacheStorageKeyNotFoundError as e:
|
||||
raise CacheKeyNotFoundError(str(e)) from e
|
||||
except CacheStorageError as e:
|
||||
raise CacheError(str(e)) from e
|
||||
|
||||
try:
|
||||
entry = pickle.loads(pickled_entry)
|
||||
if not isinstance(entry, CachedResult):
|
||||
# Loaded an old cache file format, remove it and let the caller
|
||||
# rerun the function.
|
||||
self.storage.delete(key)
|
||||
raise CacheKeyNotFoundError()
|
||||
return entry
|
||||
except pickle.UnpicklingError as exc:
|
||||
raise CacheError(f"Failed to unpickle {key}") from exc
|
||||
|
||||
@gather_metrics("_cache_data_object")
|
||||
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
|
||||
"""Write a value and associated messages to the cache.
|
||||
The value must be pickleable.
|
||||
"""
|
||||
try:
|
||||
main_id = st._main.id
|
||||
sidebar_id = st.sidebar.id
|
||||
entry = CachedResult(value, messages, main_id, sidebar_id)
|
||||
pickled_entry = pickle.dumps(entry)
|
||||
except (pickle.PicklingError, TypeError) as exc:
|
||||
raise CacheError(f"Failed to pickle {key}") from exc
|
||||
self.storage.set(key, pickled_entry)
|
||||
|
||||
def _clear(self, key: str | None = None) -> None:
|
||||
if not key:
|
||||
self.storage.clear()
|
||||
else:
|
||||
self.storage.delete(key)
|
||||
@@ -0,0 +1,142 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from streamlit import type_util
|
||||
from streamlit.errors import MarkdownFormattedException, StreamlitAPIException
|
||||
from streamlit.runtime.caching.cache_type import CacheType, get_decorator_api_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FunctionType
|
||||
|
||||
CACHE_DOCS_URL = "https://docs.streamlit.io/develop/concepts/architecture/caching"
|
||||
|
||||
|
||||
def get_cached_func_name_md(func: Any) -> str:
|
||||
"""Get markdown representation of the function name."""
|
||||
if hasattr(func, "__name__"):
|
||||
return f"`{func.__name__}()`"
|
||||
elif hasattr(type(func), "__name__"):
|
||||
return f"`{type(func).__name__}`"
|
||||
return f"`{type(func)}`"
|
||||
|
||||
|
||||
def get_return_value_type(return_value: Any) -> str:
|
||||
if hasattr(return_value, "__module__") and hasattr(type(return_value), "__name__"):
|
||||
return f"`{return_value.__module__}.{type(return_value).__name__}`"
|
||||
return get_cached_func_name_md(return_value)
|
||||
|
||||
|
||||
class UnhashableTypeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnhashableParamError(StreamlitAPIException):
|
||||
def __init__(
|
||||
self,
|
||||
cache_type: CacheType,
|
||||
func: FunctionType,
|
||||
arg_name: str | None,
|
||||
arg_value: Any,
|
||||
orig_exc: BaseException,
|
||||
):
|
||||
msg = self._create_message(cache_type, func, arg_name, arg_value)
|
||||
super().__init__(msg)
|
||||
self.with_traceback(orig_exc.__traceback__)
|
||||
|
||||
@staticmethod
|
||||
def _create_message(
|
||||
cache_type: CacheType,
|
||||
func: FunctionType,
|
||||
arg_name: str | None,
|
||||
arg_value: Any,
|
||||
) -> str:
|
||||
arg_name_str = arg_name if arg_name is not None else "(unnamed)"
|
||||
arg_type = type_util.get_fqn_type(arg_value)
|
||||
func_name = func.__name__
|
||||
arg_replacement_name = f"_{arg_name}" if arg_name is not None else "_arg"
|
||||
|
||||
return (
|
||||
f"""
|
||||
Cannot hash argument '{arg_name_str}' (of type `{arg_type}`) in '{func_name}'.
|
||||
|
||||
To address this, you can tell Streamlit not to hash this argument by adding a
|
||||
leading underscore to the argument's name in the function signature:
|
||||
|
||||
```
|
||||
@st.{get_decorator_api_name(cache_type)}
|
||||
def {func_name}({arg_replacement_name}, ...):
|
||||
...
|
||||
```
|
||||
"""
|
||||
).strip("\n")
|
||||
|
||||
|
||||
class CacheKeyNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CacheError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CacheReplayClosureError(StreamlitAPIException):
|
||||
def __init__(
|
||||
self,
|
||||
cache_type: CacheType,
|
||||
cached_func: FunctionType,
|
||||
):
|
||||
func_name = get_cached_func_name_md(cached_func)
|
||||
decorator_name = get_decorator_api_name(cache_type)
|
||||
|
||||
msg = (
|
||||
f"""
|
||||
While running {func_name}, a streamlit element is called on some layout block
|
||||
created outside the function. This is incompatible with replaying the cached
|
||||
effect of that element, because the the referenced block might not exist when
|
||||
the replay happens.
|
||||
|
||||
How to fix this:
|
||||
* Move the creation of $THING inside {func_name}.
|
||||
* Move the call to the streamlit element outside of {func_name}.
|
||||
* Remove the `@st.{decorator_name}` decorator from {func_name}.
|
||||
"""
|
||||
).strip("\n")
|
||||
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class UnserializableReturnValueError(MarkdownFormattedException):
|
||||
def __init__(self, func: FunctionType, return_value: FunctionType):
|
||||
MarkdownFormattedException.__init__(
|
||||
self,
|
||||
f"""
|
||||
Cannot serialize the return value (of type {get_return_value_type(return_value)})
|
||||
in {get_cached_func_name_md(func)}. `st.cache_data` uses
|
||||
[pickle](https://docs.python.org/3/library/pickle.html) to serialize the
|
||||
function's return value and safely store it in the cache
|
||||
without mutating the original object. Please convert the return value to a
|
||||
pickle-serializable type. If you want to cache unserializable objects such
|
||||
as database connections or Tensorflow sessions, use `st.cache_resource`
|
||||
instead (see [our docs]({CACHE_DOCS_URL}) for differences).""",
|
||||
)
|
||||
|
||||
|
||||
class UnevaluatedDataFrameError(StreamlitAPIException):
|
||||
"""Used to display a message about uncollected dataframe being used."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,527 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""@st.cache_resource implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Callable, Final, TypeVar, cast, overload
|
||||
|
||||
from cachetools import TTLCache
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import streamlit as st
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching import cache_utils
|
||||
from streamlit.runtime.caching.cache_errors import CacheKeyNotFoundError
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
from streamlit.runtime.caching.cache_utils import (
|
||||
Cache,
|
||||
CachedFuncInfo,
|
||||
make_cached_func_wrapper,
|
||||
)
|
||||
from streamlit.runtime.caching.cached_message_replay import (
|
||||
CachedMessageReplayContext,
|
||||
CachedResult,
|
||||
MsgData,
|
||||
show_widget_replay_deprecation,
|
||||
)
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
|
||||
from streamlit.time_util import time_to_seconds
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import types
|
||||
from datetime import timedelta
|
||||
|
||||
from streamlit.runtime.caching.hashing import HashFuncsDict
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.RESOURCE)
|
||||
|
||||
ValidateFunc: TypeAlias = Callable[[Any], bool]
|
||||
|
||||
|
||||
def _equal_validate_funcs(a: ValidateFunc | None, b: ValidateFunc | None) -> bool:
|
||||
"""True if the two validate functions are equal for the purposes of
|
||||
determining whether a given function cache needs to be recreated.
|
||||
"""
|
||||
# To "properly" test for function equality here, we'd need to compare function bytecode.
|
||||
# For performance reasons, We've decided not to do that for now.
|
||||
return (a is None and b is None) or (a is not None and b is not None)
|
||||
|
||||
|
||||
class ResourceCaches(CacheStatsProvider):
|
||||
"""Manages all ResourceCache instances."""
|
||||
|
||||
def __init__(self):
|
||||
self._caches_lock = threading.Lock()
|
||||
self._function_caches: dict[str, ResourceCache] = {}
|
||||
|
||||
def get_cache(
|
||||
self,
|
||||
key: str,
|
||||
display_name: str,
|
||||
max_entries: int | float | None,
|
||||
ttl: float | timedelta | str | None,
|
||||
validate: ValidateFunc | None,
|
||||
) -> ResourceCache:
|
||||
"""Return the mem cache for the given key.
|
||||
|
||||
If it doesn't exist, create a new one with the given params.
|
||||
"""
|
||||
if max_entries is None:
|
||||
max_entries = math.inf
|
||||
|
||||
ttl_seconds = time_to_seconds(ttl)
|
||||
|
||||
# Get the existing cache, if it exists, and validate that its params
|
||||
# haven't changed.
|
||||
with self._caches_lock:
|
||||
cache = self._function_caches.get(key)
|
||||
if (
|
||||
cache is not None
|
||||
and cache.ttl_seconds == ttl_seconds
|
||||
and cache.max_entries == max_entries
|
||||
and _equal_validate_funcs(cache.validate, validate)
|
||||
):
|
||||
return cache
|
||||
|
||||
# Create a new cache object and put it in our dict
|
||||
_LOGGER.debug("Creating new ResourceCache (key=%s)", key)
|
||||
cache = ResourceCache(
|
||||
key=key,
|
||||
display_name=display_name,
|
||||
max_entries=max_entries,
|
||||
ttl_seconds=ttl_seconds,
|
||||
validate=validate,
|
||||
)
|
||||
self._function_caches[key] = cache
|
||||
return cache
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""Clear all resource caches."""
|
||||
with self._caches_lock:
|
||||
self._function_caches = {}
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
with self._caches_lock:
|
||||
# Shallow-clone our caches. We don't want to hold the global
|
||||
# lock during stats-gathering.
|
||||
function_caches = self._function_caches.copy()
|
||||
|
||||
stats: list[CacheStat] = []
|
||||
for cache in function_caches.values():
|
||||
stats.extend(cache.get_stats())
|
||||
return group_stats(stats)
|
||||
|
||||
|
||||
# Singleton ResourceCaches instance
|
||||
_resource_caches = ResourceCaches()
|
||||
|
||||
|
||||
def get_resource_cache_stats_provider() -> CacheStatsProvider:
|
||||
"""Return the StatsProvider for all @st.cache_resource functions."""
|
||||
return _resource_caches
|
||||
|
||||
|
||||
class CachedResourceFuncInfo(CachedFuncInfo):
|
||||
"""Implements the CachedFuncInfo interface for @st.cache_resource."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: types.FunctionType,
|
||||
show_spinner: bool | str,
|
||||
max_entries: int | None,
|
||||
ttl: float | timedelta | str | None,
|
||||
validate: ValidateFunc | None,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
func,
|
||||
show_spinner=show_spinner,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
self.max_entries = max_entries
|
||||
self.ttl = ttl
|
||||
self.validate = validate
|
||||
|
||||
@property
|
||||
def cache_type(self) -> CacheType:
|
||||
return CacheType.RESOURCE
|
||||
|
||||
@property
|
||||
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
|
||||
return CACHE_RESOURCE_MESSAGE_REPLAY_CTX
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""A human-readable name for the cached function."""
|
||||
return f"{self.func.__module__}.{self.func.__qualname__}"
|
||||
|
||||
def get_function_cache(self, function_key: str) -> Cache:
|
||||
return _resource_caches.get_cache(
|
||||
key=function_key,
|
||||
display_name=self.display_name,
|
||||
max_entries=self.max_entries,
|
||||
ttl=self.ttl,
|
||||
validate=self.validate,
|
||||
)
|
||||
|
||||
|
||||
class CacheResourceAPI:
|
||||
"""Implements the public st.cache_resource API: the @st.cache_resource decorator,
|
||||
and st.cache_resource.clear().
|
||||
"""
|
||||
|
||||
def __init__(self, decorator_metric_name: str):
|
||||
"""Create a CacheResourceAPI instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decorator_metric_name
|
||||
The metric name to record for decorator usage.
|
||||
"""
|
||||
|
||||
# Parameterize the decorator metric name.
|
||||
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
|
||||
self._decorator = gather_metrics(decorator_metric_name, self._decorator) # type: ignore
|
||||
|
||||
# Type-annotate the decorator function.
|
||||
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
# Bare decorator usage
|
||||
@overload
|
||||
def __call__(self, func: F) -> F: ...
|
||||
|
||||
# Decorator with arguments
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
ttl: float | timedelta | str | None = None,
|
||||
max_entries: int | None = None,
|
||||
show_spinner: bool | str = True,
|
||||
validate: ValidateFunc | None = None,
|
||||
experimental_allow_widgets: bool = False,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
) -> Callable[[F], F]: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
func: F | None = None,
|
||||
*,
|
||||
ttl: float | timedelta | str | None = None,
|
||||
max_entries: int | None = None,
|
||||
show_spinner: bool | str = True,
|
||||
validate: ValidateFunc | None = None,
|
||||
experimental_allow_widgets: bool = False,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
):
|
||||
return self._decorator(
|
||||
func,
|
||||
ttl=ttl,
|
||||
max_entries=max_entries,
|
||||
show_spinner=show_spinner,
|
||||
validate=validate,
|
||||
experimental_allow_widgets=experimental_allow_widgets,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
|
||||
def _decorator(
|
||||
self,
|
||||
func: F | None,
|
||||
*,
|
||||
ttl: float | timedelta | str | None,
|
||||
max_entries: int | None,
|
||||
show_spinner: bool | str,
|
||||
validate: ValidateFunc | None,
|
||||
experimental_allow_widgets: bool,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
):
|
||||
"""Decorator to cache functions that return global resources (e.g. database connections, ML models).
|
||||
|
||||
Cached objects are shared across all users, sessions, and reruns. They
|
||||
must be thread-safe because they can be accessed from multiple threads
|
||||
concurrently. If thread safety is an issue, consider using ``st.session_state``
|
||||
to store resources per session instead.
|
||||
|
||||
You can clear a function's cache with ``func.clear()`` or clear the entire
|
||||
cache with ``st.cache_resource.clear()``.
|
||||
|
||||
A function's arguments must be hashable to cache it. If you have an
|
||||
unhashable argument (like a database connection) or an argument you
|
||||
want to exclude from caching, use an underscore prefix in the argument
|
||||
name. In this case, Streamlit will return a cached value when all other
|
||||
arguments match a previous function call. Alternatively, you can
|
||||
declare custom hashing functions with ``hash_funcs``.
|
||||
|
||||
To cache data, use ``st.cache_data`` instead. Learn more about caching at
|
||||
https://docs.streamlit.io/develop/concepts/architecture/caching.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
The function that creates the cached resource. Streamlit hashes the
|
||||
function's source code.
|
||||
|
||||
ttl : float, timedelta, str, or None
|
||||
The maximum time to keep an entry in the cache. Can be one of:
|
||||
|
||||
- ``None`` if cache entries should never expire (default).
|
||||
- A number specifying the time in seconds.
|
||||
- A string specifying the time in a format supported by `Pandas's
|
||||
Timedelta constructor <https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html>`_,
|
||||
e.g. ``"1d"``, ``"1.5 days"``, or ``"1h23s"``.
|
||||
- A ``timedelta`` object from `Python's built-in datetime library
|
||||
<https://docs.python.org/3/library/datetime.html#timedelta-objects>`_,
|
||||
e.g. ``timedelta(days=1)``.
|
||||
|
||||
max_entries : int or None
|
||||
The maximum number of entries to keep in the cache, or None
|
||||
for an unbounded cache. When a new entry is added to a full cache,
|
||||
the oldest cached entry will be removed. Defaults to None.
|
||||
|
||||
show_spinner : bool or str
|
||||
Enable the spinner. Default is True to show a spinner when there is
|
||||
a "cache miss" and the cached resource is being created. If string,
|
||||
value of show_spinner param will be used for spinner text.
|
||||
|
||||
validate : callable or None
|
||||
An optional validation function for cached data. ``validate`` is called
|
||||
each time the cached value is accessed. It receives the cached value as
|
||||
its only parameter and it must return a boolean. If ``validate`` returns
|
||||
False, the current cached value is discarded, and the decorated function
|
||||
is called to compute a new value. This is useful e.g. to check the
|
||||
health of database connections.
|
||||
|
||||
experimental_allow_widgets : bool
|
||||
Allow widgets to be used in the cached function. Defaults to False.
|
||||
|
||||
hash_funcs : dict or None
|
||||
Mapping of types or fully qualified names to hash functions.
|
||||
This is used to override the behavior of the hasher inside Streamlit's
|
||||
caching mechanism: when the hasher encounters an object, it will first
|
||||
check to see if its type matches a key in this dict and, if so, will use
|
||||
the provided function to generate a hash for it. See below for an example
|
||||
of how this can be used.
|
||||
|
||||
.. deprecated::
|
||||
The cached widget replay functionality was removed in 1.38. Please
|
||||
remove the ``experimental_allow_widgets`` parameter from your
|
||||
caching decorators. This parameter will be removed in a future
|
||||
version.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_resource
|
||||
... def get_database_session(url):
|
||||
... # Create a database session object that points to the URL.
|
||||
... return session
|
||||
>>>
|
||||
>>> s1 = get_database_session(SESSION_URL_1)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> s2 = get_database_session(SESSION_URL_1)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value. This means that now the connection object in s1 is the same as in s2.
|
||||
>>>
|
||||
>>> s3 = get_database_session(SESSION_URL_2)
|
||||
>>> # This is a different URL, so the function executes.
|
||||
|
||||
By default, all parameters to a cache_resource function must be hashable.
|
||||
Any parameter whose name begins with ``_`` will not be hashed. You can use
|
||||
this as an "escape hatch" for parameters that are not hashable:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_resource
|
||||
... def get_database_session(_sessionmaker, url):
|
||||
... # Create a database connection object that points to the URL.
|
||||
... return connection
|
||||
>>>
|
||||
>>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value - even though the _sessionmaker parameter was different
|
||||
>>> # in both calls.
|
||||
|
||||
A cache_resource function's cache can be procedurally cleared:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_resource
|
||||
... def get_database_session(_sessionmaker, url):
|
||||
... # Create a database connection object that points to the URL.
|
||||
... return connection
|
||||
>>>
|
||||
>>> fetch_and_clean_data.clear(_sessionmaker, "https://streamlit.io/")
|
||||
>>> # Clear the cached entry for the arguments provided.
|
||||
>>>
|
||||
>>> get_database_session.clear()
|
||||
>>> # Clear all cached entries for this function.
|
||||
|
||||
To override the default hashing behavior, pass a custom hash function.
|
||||
You can do that by mapping a type (e.g. ``Person``) to a hash
|
||||
function (``str``) like this:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> from pydantic import BaseModel
|
||||
>>>
|
||||
>>> class Person(BaseModel):
|
||||
... name: str
|
||||
>>>
|
||||
>>> @st.cache_resource(hash_funcs={Person: str})
|
||||
... def get_person_name(person: Person):
|
||||
... return person.name
|
||||
|
||||
Alternatively, you can map the type's fully-qualified name
|
||||
(e.g. ``"__main__.Person"``) to the hash function instead:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> from pydantic import BaseModel
|
||||
>>>
|
||||
>>> class Person(BaseModel):
|
||||
... name: str
|
||||
>>>
|
||||
>>> @st.cache_resource(hash_funcs={"__main__.Person": str})
|
||||
... def get_person_name(person: Person):
|
||||
... return person.name
|
||||
"""
|
||||
if experimental_allow_widgets:
|
||||
show_widget_replay_deprecation("cache_resource")
|
||||
|
||||
# Support passing the params via function decorator, e.g.
|
||||
# @st.cache_resource(show_spinner=False)
|
||||
if func is None:
|
||||
return lambda f: make_cached_func_wrapper(
|
||||
CachedResourceFuncInfo(
|
||||
func=f,
|
||||
show_spinner=show_spinner,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
validate=validate,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
)
|
||||
|
||||
return make_cached_func_wrapper(
|
||||
CachedResourceFuncInfo(
|
||||
func=cast("types.FunctionType", func),
|
||||
show_spinner=show_spinner,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
validate=validate,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
)
|
||||
|
||||
@gather_metrics("clear_resource_caches")
|
||||
def clear(self) -> None:
|
||||
"""Clear all cache_resource caches."""
|
||||
_resource_caches.clear_all()
|
||||
|
||||
|
||||
class ResourceCache(Cache):
|
||||
"""Manages cached values for a single st.cache_resource function."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
max_entries: float,
|
||||
ttl_seconds: float,
|
||||
validate: ValidateFunc | None,
|
||||
display_name: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.display_name = display_name
|
||||
self._mem_cache: TTLCache[str, CachedResult] = TTLCache(
|
||||
maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
|
||||
)
|
||||
self._mem_cache_lock = threading.Lock()
|
||||
self.validate = validate
|
||||
|
||||
@property
|
||||
def max_entries(self) -> float:
|
||||
return self._mem_cache.maxsize
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> float:
|
||||
return self._mem_cache.ttl
|
||||
|
||||
def read_result(self, key: str) -> CachedResult:
|
||||
"""Read a value and associated messages from the cache.
|
||||
Raise `CacheKeyNotFoundError` if the value doesn't exist.
|
||||
"""
|
||||
with self._mem_cache_lock:
|
||||
if key not in self._mem_cache:
|
||||
# key does not exist in cache.
|
||||
raise CacheKeyNotFoundError()
|
||||
|
||||
result = self._mem_cache[key]
|
||||
|
||||
if self.validate is not None and not self.validate(result.value):
|
||||
# Validate failed: delete the entry and raise an error.
|
||||
del self._mem_cache[key]
|
||||
raise CacheKeyNotFoundError()
|
||||
|
||||
return result
|
||||
|
||||
@gather_metrics("_cache_resource_object")
|
||||
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
|
||||
"""Write a value and associated messages to the cache."""
|
||||
main_id = st._main.id
|
||||
sidebar_id = st.sidebar.id
|
||||
|
||||
with self._mem_cache_lock:
|
||||
self._mem_cache[key] = CachedResult(value, messages, main_id, sidebar_id)
|
||||
|
||||
def _clear(self, key: str | None = None) -> None:
|
||||
with self._mem_cache_lock:
|
||||
if key is None:
|
||||
self._mem_cache.clear()
|
||||
elif key in self._mem_cache:
|
||||
del self._mem_cache[key]
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
# Shallow clone our cache. Computing item sizes is potentially
|
||||
# expensive, and we want to minimize the time we spend holding
|
||||
# the lock.
|
||||
with self._mem_cache_lock:
|
||||
cache_entries = list(self._mem_cache.values())
|
||||
|
||||
# Lazy-load vendored package to prevent import of numpy
|
||||
from streamlit.vendor.pympler.asizeof import asizeof
|
||||
|
||||
return [
|
||||
CacheStat(
|
||||
category_name="st_cache_resource",
|
||||
cache_name=self.display_name,
|
||||
byte_length=asizeof(entry),
|
||||
)
|
||||
for entry in cache_entries
|
||||
]
|
||||
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
|
||||
|
||||
class CacheType(enum.Enum):
|
||||
"""The function cache types we implement."""
|
||||
|
||||
DATA = "DATA"
|
||||
RESOURCE = "RESOURCE"
|
||||
|
||||
|
||||
def get_decorator_api_name(cache_type: CacheType) -> str:
|
||||
"""Return the name of the public decorator API for the given CacheType."""
|
||||
if cache_type is CacheType.DATA:
|
||||
return "cache_data"
|
||||
if cache_type is CacheType.RESOURCE:
|
||||
return "cache_resource"
|
||||
raise RuntimeError(f"Unrecognized CacheType '{cache_type}'")
|
||||
@@ -0,0 +1,525 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Common cache logic shared by st.cache_data and st.cache_resource."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import threading
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Callable, Final
|
||||
|
||||
from streamlit import type_util
|
||||
from streamlit.dataframe_util import is_unevaluated_data_object
|
||||
from streamlit.elements.spinner import spinner
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching.cache_errors import (
|
||||
CacheError,
|
||||
CacheKeyNotFoundError,
|
||||
UnevaluatedDataFrameError,
|
||||
UnhashableParamError,
|
||||
UnhashableTypeError,
|
||||
UnserializableReturnValueError,
|
||||
get_cached_func_name_md,
|
||||
)
|
||||
from streamlit.runtime.caching.cached_message_replay import (
|
||||
CachedMessageReplayContext,
|
||||
CachedResult,
|
||||
MsgData,
|
||||
replay_cached_messages,
|
||||
)
|
||||
from streamlit.runtime.caching.hashing import HashFuncsDict, update_hash
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import (
|
||||
in_cached_function,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FunctionType
|
||||
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
# The timer function we use with TTLCache. This is the default timer func, but
|
||||
# is exposed here as a constant so that it can be patched in unit tests.
|
||||
TTLCACHE_TIMER = time.monotonic
|
||||
|
||||
|
||||
class Cache:
|
||||
"""Function cache interface. Caches persist across script runs."""
|
||||
|
||||
def __init__(self):
|
||||
self._value_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
|
||||
self._value_locks_lock = threading.Lock()
|
||||
|
||||
@abstractmethod
|
||||
def read_result(self, value_key: str) -> CachedResult:
|
||||
"""Read a value and associated messages from the cache.
|
||||
|
||||
Raises
|
||||
------
|
||||
CacheKeyNotFoundError
|
||||
Raised if value_key is not in the cache.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def write_result(self, value_key: str, value: Any, messages: list[MsgData]) -> None:
|
||||
"""Write a value and associated messages to the cache, overwriting any existing
|
||||
result that uses the value_key.
|
||||
"""
|
||||
# We *could* `del self._value_locks[value_key]` here, since nobody will be taking
|
||||
# a compute_value_lock for this value_key after the result is written.
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_value_lock(self, value_key: str) -> threading.Lock:
|
||||
"""Return the lock that should be held while computing a new cached value.
|
||||
In a popular app with a cache that hasn't been pre-warmed, many sessions may try
|
||||
to access a not-yet-cached value simultaneously. We use a lock to ensure that
|
||||
only one of those sessions computes the value, and the others block until
|
||||
the value is computed.
|
||||
"""
|
||||
with self._value_locks_lock:
|
||||
return self._value_locks[value_key]
|
||||
|
||||
def clear(self, key: str | None = None):
|
||||
"""Clear values from this cache.
|
||||
If no argument is passed, all items are cleared from the cache.
|
||||
A key can be passed to clear that key from the cache only.
|
||||
"""
|
||||
with self._value_locks_lock:
|
||||
if not key:
|
||||
self._value_locks.clear()
|
||||
elif key in self._value_locks:
|
||||
del self._value_locks[key]
|
||||
self._clear(key=key)
|
||||
|
||||
@abstractmethod
|
||||
def _clear(self, key: str | None = None) -> None:
|
||||
"""Subclasses must implement this to perform cache-clearing logic."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CachedFuncInfo:
|
||||
"""Encapsulates data for a cached function instance.
|
||||
|
||||
CachedFuncInfo instances are scoped to a single script run - they're not
|
||||
persistent.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: FunctionType,
|
||||
show_spinner: bool | str,
|
||||
hash_funcs: HashFuncsDict | None,
|
||||
):
|
||||
self.func = func
|
||||
self.show_spinner = show_spinner
|
||||
self.hash_funcs = hash_funcs
|
||||
|
||||
@property
|
||||
def cache_type(self) -> CacheType:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_function_cache(self, function_key: str) -> Cache:
|
||||
"""Get or create the function cache for the given key."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def make_cached_func_wrapper(info: CachedFuncInfo) -> Callable[..., Any]:
|
||||
"""Create a callable wrapper around a CachedFunctionInfo.
|
||||
|
||||
Calling the wrapper will return the cached value if it's already been
|
||||
computed, and will call the underlying function to compute and cache the
|
||||
value otherwise.
|
||||
|
||||
The wrapper also has a `clear` function that can be called to clear
|
||||
some or all of the wrapper's cached values.
|
||||
"""
|
||||
cached_func = CachedFunc(info)
|
||||
return functools.update_wrapper(cached_func, info.func)
|
||||
|
||||
|
||||
class BoundCachedFunc:
|
||||
"""A wrapper around a CachedFunc that binds it to a specific instance in case of
|
||||
decorated function is a class method.
|
||||
"""
|
||||
|
||||
def __init__(self, cached_func: CachedFunc, instance: Any):
|
||||
self._cached_func = cached_func
|
||||
self._instance = instance
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
return self._cached_func(self._instance, *args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<BoundCachedFunc: {self._cached_func._info.func} of {self._instance}>"
|
||||
|
||||
def clear(self, *args, **kwargs):
|
||||
if args or kwargs:
|
||||
# The instance is required as first parameter to allow
|
||||
# args to be correctly resolved to the parameter names:
|
||||
self._cached_func.clear(self._instance, *args, **kwargs)
|
||||
else:
|
||||
# if no args/kwargs are specified, we just want to clear the
|
||||
# entire cache of this method:
|
||||
self._cached_func.clear()
|
||||
|
||||
|
||||
class CachedFunc:
|
||||
def __init__(self, info: CachedFuncInfo):
|
||||
self._info = info
|
||||
self._function_key = _make_function_key(info.cache_type, info.func)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<CachedFunc: {self._info.func}>"
|
||||
|
||||
def __get__(self, instance, owner=None):
|
||||
"""CachedFunc implements descriptor protocol to support cache methods."""
|
||||
if instance is None:
|
||||
return self
|
||||
|
||||
return functools.update_wrapper(BoundCachedFunc(self, instance), self)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""The wrapper. We'll only call our underlying function on a cache miss."""
|
||||
|
||||
spinner_message: str | None = None
|
||||
if isinstance(self._info.show_spinner, str):
|
||||
spinner_message = self._info.show_spinner
|
||||
elif self._info.show_spinner is True:
|
||||
name = self._info.func.__qualname__
|
||||
if len(args) == 0 and len(kwargs) == 0:
|
||||
spinner_message = f"Running `{name}()`."
|
||||
else:
|
||||
spinner_message = f"Running `{name}(...)`."
|
||||
|
||||
return self._get_or_create_cached_value(args, kwargs, spinner_message)
|
||||
|
||||
def _get_or_create_cached_value(
|
||||
self,
|
||||
func_args: tuple[Any, ...],
|
||||
func_kwargs: dict[str, Any],
|
||||
spinner_message: str | None = None,
|
||||
) -> Any:
|
||||
# Retrieve the function's cache object. We must do this "just-in-time"
|
||||
# (as opposed to in the constructor), because caches can be invalidated
|
||||
# at any time.
|
||||
cache = self._info.get_function_cache(self._function_key)
|
||||
|
||||
# Generate the key for the cached value. This is based on the
|
||||
# arguments passed to the function.
|
||||
value_key = _make_value_key(
|
||||
cache_type=self._info.cache_type,
|
||||
func=self._info.func,
|
||||
func_args=func_args,
|
||||
func_kwargs=func_kwargs,
|
||||
hash_funcs=self._info.hash_funcs,
|
||||
)
|
||||
|
||||
with contextlib.suppress(CacheKeyNotFoundError):
|
||||
cached_result = cache.read_result(value_key)
|
||||
return self._handle_cache_hit(cached_result)
|
||||
|
||||
# only show spinner if there is a message to show and always only for the
|
||||
# outermost cache function if cache functions are nested, because the outermost
|
||||
# function has to wait for the inner functions anyways. This avoids surprising
|
||||
# users with slowdowned apps in case the inner functions are called very often,
|
||||
# which would lead to a ton of (empty/spinner) proto messages that will make the
|
||||
# app slow (see https://github.com/streamlit/streamlit/issues/9951). This is
|
||||
# basically like auto-setting "show_spinner=False" on the @st.cache decorators
|
||||
# on behalf of the user.
|
||||
is_nested_cache_function = in_cached_function.get()
|
||||
spinner_or_no_context = (
|
||||
spinner(spinner_message, _cache=True)
|
||||
if spinner_message is not None and not is_nested_cache_function
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
with spinner_or_no_context:
|
||||
return self._handle_cache_miss(cache, value_key, func_args, func_kwargs)
|
||||
|
||||
def _handle_cache_hit(self, result: CachedResult) -> Any:
|
||||
"""Handle a cache hit: replay the result's cached messages, and return its
|
||||
value.
|
||||
"""
|
||||
replay_cached_messages(
|
||||
result,
|
||||
self._info.cache_type,
|
||||
self._info.func,
|
||||
)
|
||||
return result.value
|
||||
|
||||
def _handle_cache_miss(
|
||||
self,
|
||||
cache: Cache,
|
||||
value_key: str,
|
||||
func_args: tuple[Any, ...],
|
||||
func_kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Handle a cache miss: compute a new cached value, write it back to the cache,
|
||||
and return that newly-computed value.
|
||||
"""
|
||||
|
||||
# Implementation notes:
|
||||
# - We take a "compute_value_lock" before computing our value. This ensures that
|
||||
# multiple sessions don't try to compute the same value simultaneously.
|
||||
#
|
||||
# - We use a different lock for each value_key, as opposed to a single lock for
|
||||
# the entire cache, so that unrelated value computations don't block on each other.
|
||||
#
|
||||
# - When retrieving a cache entry that may not yet exist, we use a "double-checked locking"
|
||||
# strategy: first we try to retrieve the cache entry without taking a value lock. (This
|
||||
# happens in `_get_or_create_cached_value()`.) If that fails because the value hasn't
|
||||
# been computed yet, we take the value lock and then immediately try to retrieve cache entry
|
||||
# *again*, while holding the lock. If the cache entry exists at this point, it means that
|
||||
# another thread computed the value before us.
|
||||
#
|
||||
# This means that the happy path ("cache entry exists") is a wee bit faster because
|
||||
# no lock is acquired. But the unhappy path ("cache entry needs to be recomputed") is
|
||||
# a wee bit slower, because we do two lookups for the entry.
|
||||
|
||||
with cache.compute_value_lock(value_key):
|
||||
# We've acquired the lock - but another thread may have acquired it first
|
||||
# and already computed the value. So we need to test for a cache hit again,
|
||||
# before computing.
|
||||
try:
|
||||
cached_result = cache.read_result(value_key)
|
||||
# Another thread computed the value before us. Early exit!
|
||||
return self._handle_cache_hit(cached_result)
|
||||
except CacheKeyNotFoundError:
|
||||
# No cache hit -> we will call the cached function
|
||||
# below.
|
||||
pass
|
||||
|
||||
# We acquired the lock before any other thread. Compute the value!
|
||||
with self._info.cached_message_replay_ctx.calling_cached_function(
|
||||
self._info.func
|
||||
):
|
||||
computed_value = self._info.func(*func_args, **func_kwargs)
|
||||
|
||||
# We've computed our value, and now we need to write it back to the cache
|
||||
# along with any "replay messages" that were generated during value computation.
|
||||
messages = self._info.cached_message_replay_ctx._most_recent_messages
|
||||
try:
|
||||
cache.write_result(value_key, computed_value, messages)
|
||||
return computed_value
|
||||
except (CacheError, RuntimeError) as ex:
|
||||
# An exception was thrown while we tried to write to the cache. Report
|
||||
# it to the user. (We catch `RuntimeError` here because it will be
|
||||
# raised by Apache Spark if we do not collect dataframe before
|
||||
# using `st.cache_data`.)
|
||||
if is_unevaluated_data_object(computed_value):
|
||||
# If the returned value is an unevaluated dataframe, raise an error.
|
||||
# Unevaluated dataframes are not yet in the local memory, which also
|
||||
# means they cannot be properly cached (serialized).
|
||||
raise UnevaluatedDataFrameError(
|
||||
f"The function {get_cached_func_name_md(self._info.func)} is "
|
||||
"decorated with `st.cache_data` but it returns an unevaluated "
|
||||
f"data object of type `{type_util.get_fqn_type(computed_value)}`. "
|
||||
"Please convert the object to a serializable format "
|
||||
"(e.g. Pandas DataFrame) before returning it, so "
|
||||
"`st.cache_data` can serialize and cache it."
|
||||
) from ex
|
||||
raise UnserializableReturnValueError(
|
||||
return_value=computed_value, func=self._info.func
|
||||
)
|
||||
|
||||
def clear(self, *args, **kwargs):
|
||||
"""Clear the cached function's associated cache.
|
||||
|
||||
If no arguments are passed, Streamlit will clear all values cached for
|
||||
the function. If arguments are passed, Streamlit will clear the cached
|
||||
value for these arguments only.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
*args: Any
|
||||
Arguments of the cached functions.
|
||||
|
||||
**kwargs: Any
|
||||
Keyword arguments of the cached function.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import streamlit as st
|
||||
>>> import time
|
||||
>>>
|
||||
>>> @st.cache_data
|
||||
>>> def foo(bar):
|
||||
>>> time.sleep(2)
|
||||
>>> st.write(f"Executed foo({bar}).")
|
||||
>>> return bar
|
||||
>>>
|
||||
>>> if st.button("Clear all cached values for `foo`", on_click=foo.clear):
|
||||
>>> foo.clear()
|
||||
>>>
|
||||
>>> if st.button("Clear the cached value of `foo(1)`"):
|
||||
>>> foo.clear(1)
|
||||
>>>
|
||||
>>> foo(1)
|
||||
>>> foo(2)
|
||||
|
||||
"""
|
||||
cache = self._info.get_function_cache(self._function_key)
|
||||
if args or kwargs:
|
||||
key = _make_value_key(
|
||||
cache_type=self._info.cache_type,
|
||||
func=self._info.func,
|
||||
func_args=args,
|
||||
func_kwargs=kwargs,
|
||||
hash_funcs=self._info.hash_funcs,
|
||||
)
|
||||
else:
|
||||
key = None
|
||||
cache.clear(key=key)
|
||||
|
||||
|
||||
def _make_value_key(
|
||||
cache_type: CacheType,
|
||||
func: FunctionType,
|
||||
func_args: tuple[Any, ...],
|
||||
func_kwargs: dict[str, Any],
|
||||
hash_funcs: HashFuncsDict | None,
|
||||
) -> str:
|
||||
"""Create the key for a value within a cache.
|
||||
|
||||
This key is generated from the function's arguments. All arguments
|
||||
will be hashed, except for those named with a leading "_".
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
Raised (with a nicely-formatted explanation message) if we encounter
|
||||
an un-hashable arg.
|
||||
"""
|
||||
|
||||
# Create a (name, value) list of all *args and **kwargs passed to the
|
||||
# function.
|
||||
arg_pairs: list[tuple[str | None, Any]] = []
|
||||
for arg_idx in range(len(func_args)):
|
||||
arg_name = _get_positional_arg_name(func, arg_idx)
|
||||
arg_pairs.append((arg_name, func_args[arg_idx]))
|
||||
|
||||
for kw_name, kw_val in func_kwargs.items():
|
||||
# **kwargs ordering is preserved, per PEP 468
|
||||
# https://www.python.org/dev/peps/pep-0468/, so this iteration is
|
||||
# deterministic.
|
||||
arg_pairs.append((kw_name, kw_val))
|
||||
|
||||
# Create the hash from each arg value, except for those args whose name
|
||||
# starts with "_". (Underscore-prefixed args are deliberately excluded from
|
||||
# hashing.)
|
||||
args_hasher = hashlib.new("md5", usedforsecurity=False)
|
||||
for arg_name, arg_value in arg_pairs:
|
||||
if arg_name is not None and arg_name.startswith("_"):
|
||||
_LOGGER.debug("Not hashing %s because it starts with _", arg_name)
|
||||
continue
|
||||
|
||||
try:
|
||||
update_hash(
|
||||
arg_name,
|
||||
hasher=args_hasher,
|
||||
cache_type=cache_type,
|
||||
hash_source=func,
|
||||
)
|
||||
# we call update_hash twice here, first time for `arg_name`
|
||||
# without `hash_funcs`, and second time for `arg_value` with hash_funcs
|
||||
# to evaluate user defined `hash_funcs` only for computing `arg_value` hash.
|
||||
update_hash(
|
||||
arg_value,
|
||||
hasher=args_hasher,
|
||||
cache_type=cache_type,
|
||||
hash_funcs=hash_funcs,
|
||||
hash_source=func,
|
||||
)
|
||||
except UnhashableTypeError as exc:
|
||||
raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)
|
||||
|
||||
value_key = args_hasher.hexdigest()
|
||||
_LOGGER.debug("Cache key: %s", value_key)
|
||||
|
||||
return value_key
|
||||
|
||||
|
||||
def _make_function_key(cache_type: CacheType, func: FunctionType) -> str:
|
||||
"""Create the unique key for a function's cache.
|
||||
|
||||
A function's key is stable across reruns of the app, and changes when
|
||||
the function's source code changes.
|
||||
"""
|
||||
func_hasher = hashlib.new("md5", usedforsecurity=False)
|
||||
|
||||
# Include the function's __module__ and __qualname__ strings in the hash.
|
||||
# This means that two identical functions in different modules
|
||||
# will not share a hash; it also means that two identical *nested*
|
||||
# functions in the same module will not share a hash.
|
||||
update_hash(
|
||||
(func.__module__, func.__qualname__),
|
||||
hasher=func_hasher,
|
||||
cache_type=cache_type,
|
||||
hash_source=func,
|
||||
)
|
||||
|
||||
# Include the function's source code in its hash. If the source code can't
|
||||
# be retrieved, fall back to the function's bytecode instead.
|
||||
source_code: str | bytes
|
||||
try:
|
||||
source_code = inspect.getsource(func)
|
||||
except (OSError, TypeError) as ex:
|
||||
_LOGGER.debug(
|
||||
"Failed to retrieve function's source code when building its key; "
|
||||
"falling back to bytecode.",
|
||||
exc_info=ex,
|
||||
)
|
||||
source_code = func.__code__.co_code
|
||||
|
||||
update_hash(
|
||||
source_code, hasher=func_hasher, cache_type=cache_type, hash_source=func
|
||||
)
|
||||
|
||||
return func_hasher.hexdigest()
|
||||
|
||||
|
||||
def _get_positional_arg_name(func: FunctionType, arg_index: int) -> str | None:
|
||||
"""Return the name of a function's positional argument.
|
||||
|
||||
If arg_index is out of range, or refers to a parameter that is not a
|
||||
named positional argument (e.g. an *args, **kwargs, or keyword-only param),
|
||||
return None instead.
|
||||
"""
|
||||
if arg_index < 0:
|
||||
return None
|
||||
|
||||
params: list[inspect.Parameter] = list(inspect.signature(func).parameters.values())
|
||||
if arg_index >= len(params):
|
||||
return None
|
||||
|
||||
if params[arg_index].kind in (
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
):
|
||||
return params[arg_index].name
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,290 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union
|
||||
|
||||
import streamlit as st
|
||||
from streamlit import runtime, util
|
||||
from streamlit.deprecation_util import show_deprecation_warning
|
||||
from streamlit.runtime.caching.cache_errors import CacheReplayClosureError
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import (
|
||||
in_cached_function,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
from types import FunctionType
|
||||
|
||||
from google.protobuf.message import Message
|
||||
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
from streamlit.proto.Block_pb2 import Block
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MediaMsgData:
|
||||
media: bytes | str
|
||||
mimetype: str
|
||||
media_id: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ElementMsgData:
|
||||
"""An element's message and related metadata for
|
||||
replaying that element's function call.
|
||||
|
||||
media_data is filled in iff this is a media element (image, audio, video).
|
||||
"""
|
||||
|
||||
delta_type: str
|
||||
message: Message
|
||||
id_of_dg_called_on: str
|
||||
returned_dgs_id: str
|
||||
media_data: list[MediaMsgData] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BlockMsgData:
|
||||
message: Block
|
||||
id_of_dg_called_on: str
|
||||
returned_dgs_id: str
|
||||
|
||||
|
||||
MsgData = Union[ElementMsgData, BlockMsgData]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedResult:
|
||||
"""The full results of calling a cache-decorated function, enough to
|
||||
replay the st functions called while executing it.
|
||||
"""
|
||||
|
||||
value: Any
|
||||
messages: list[MsgData]
|
||||
main_id: str
|
||||
sidebar_id: str
|
||||
|
||||
|
||||
"""
|
||||
Note [DeltaGenerator method invocation]
|
||||
There are two top level DG instances defined for all apps:
|
||||
`main`, which is for putting elements in the main part of the app
|
||||
`sidebar`, for the sidebar
|
||||
|
||||
There are 3 different ways an st function can be invoked:
|
||||
1. Implicitly on the main DG instance (plain `st.foo` calls)
|
||||
2. Implicitly in an active contextmanager block (`st.foo` within a `with st.container` context)
|
||||
3. Explicitly on a DG instance (`st.sidebar.foo`, `my_column_1.foo`)
|
||||
|
||||
To simplify replaying messages from a cached function result, we convert all of these
|
||||
to explicit invocations. How they get rewritten depends on if the invocation was
|
||||
implicit vs explicit, and if the target DG has been seen/produced during replay.
|
||||
|
||||
Implicit invocation on a known DG -> Explicit invocation on that DG
|
||||
Implicit invocation on an unknown DG -> Rewrite as explicit invocation on main
|
||||
with st.container():
|
||||
my_cache_decorated_function()
|
||||
|
||||
This is situation 2 above, and the DG is a block entirely outside our function call,
|
||||
so we interpret it as "put this element in the enclosing contextmanager block"
|
||||
(or main if there isn't one), which is achieved by invoking on main.
|
||||
Explicit invocation on a known DG -> No change needed
|
||||
Explicit invocation on an unknown DG -> Raise an error
|
||||
We have no way to identify the target DG, and it may not even be present in the
|
||||
current script run, so the least surprising thing to do is raise an error.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class CachedMessageReplayContext(threading.local):
|
||||
"""A utility for storing messages generated by `st` commands called inside
|
||||
a cached function.
|
||||
|
||||
Data is stored in a thread-local object, so it's safe to use an instance
|
||||
of this class across multiple threads.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_type: CacheType):
|
||||
self._cached_message_stack: list[list[MsgData]] = []
|
||||
self._seen_dg_stack: list[set[str]] = []
|
||||
self._most_recent_messages: list[MsgData] = []
|
||||
self._media_data: list[MediaMsgData] = []
|
||||
self._cache_type = cache_type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def calling_cached_function(self, func: FunctionType) -> Iterator[None]:
|
||||
"""Context manager that should wrap the invocation of a cached function.
|
||||
It allows us to track any `st.foo` messages that are generated from inside the
|
||||
function for playback during cache retrieval.
|
||||
"""
|
||||
self._cached_message_stack.append([])
|
||||
self._seen_dg_stack.append(set())
|
||||
nested_call = False
|
||||
if in_cached_function.get():
|
||||
nested_call = True
|
||||
# If we're in a cached function. To disallow usage of widget-like element,
|
||||
# we need to set the in_cached_function to true for this cached function run
|
||||
# to prevent widget usage (triggers a warning).
|
||||
in_cached_function.set(True)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._most_recent_messages = self._cached_message_stack.pop()
|
||||
self._seen_dg_stack.pop()
|
||||
if not nested_call:
|
||||
# Reset the in_cached_function flag. But only if this
|
||||
# is not nested inside a cached function that disallows widget usage.
|
||||
in_cached_function.set(False)
|
||||
|
||||
def save_element_message(
|
||||
self,
|
||||
delta_type: str,
|
||||
element_proto: Message,
|
||||
invoked_dg_id: str,
|
||||
used_dg_id: str,
|
||||
returned_dg_id: str,
|
||||
) -> None:
|
||||
"""Record the element protobuf as having been produced during any currently
|
||||
executing cached functions, so they can be replayed any time the function's
|
||||
execution is skipped because they're in the cache.
|
||||
"""
|
||||
if not runtime.exists():
|
||||
return
|
||||
if len(self._cached_message_stack) >= 1:
|
||||
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
|
||||
|
||||
media_data = self._media_data
|
||||
|
||||
element_msg_data = ElementMsgData(
|
||||
delta_type,
|
||||
element_proto,
|
||||
id_to_save,
|
||||
returned_dg_id,
|
||||
media_data,
|
||||
)
|
||||
for msgs in self._cached_message_stack:
|
||||
msgs.append(element_msg_data)
|
||||
|
||||
# Reset instance state, now that it has been used for the
|
||||
# associated element.
|
||||
self._media_data = []
|
||||
|
||||
for s in self._seen_dg_stack:
|
||||
s.add(returned_dg_id)
|
||||
|
||||
def save_block_message(
|
||||
self,
|
||||
block_proto: Block,
|
||||
invoked_dg_id: str,
|
||||
used_dg_id: str,
|
||||
returned_dg_id: str,
|
||||
) -> None:
|
||||
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
|
||||
for msgs in self._cached_message_stack:
|
||||
msgs.append(BlockMsgData(block_proto, id_to_save, returned_dg_id))
|
||||
for s in self._seen_dg_stack:
|
||||
s.add(returned_dg_id)
|
||||
|
||||
def select_dg_to_save(self, invoked_id: str, acting_on_id: str) -> str:
|
||||
"""Select the id of the DG that this message should be invoked on
|
||||
during message replay.
|
||||
|
||||
See Note [DeltaGenerator method invocation]
|
||||
|
||||
invoked_id is the DG the st function was called on, usually `st._main`.
|
||||
acting_on_id is the DG the st function ultimately runs on, which may be different
|
||||
if the invoked DG delegated to another one because it was in a `with` block.
|
||||
"""
|
||||
if len(self._seen_dg_stack) > 0 and acting_on_id in self._seen_dg_stack[-1]:
|
||||
return acting_on_id
|
||||
else:
|
||||
return invoked_id
|
||||
|
||||
def save_image_data(
|
||||
self, image_data: bytes | str, mimetype: str, image_id: str
|
||||
) -> None:
|
||||
self._media_data.append(MediaMsgData(image_data, mimetype, image_id))
|
||||
|
||||
|
||||
def replay_cached_messages(
|
||||
result: CachedResult, cache_type: CacheType, cached_func: FunctionType
|
||||
) -> None:
|
||||
"""Replay the st element function calls that happened when executing a
|
||||
cache-decorated function.
|
||||
|
||||
When a cache function is executed, we record the element and block messages
|
||||
produced, and use those to reproduce the DeltaGenerator calls, so the elements
|
||||
will appear in the web app even when execution of the function is skipped
|
||||
because the result was cached.
|
||||
|
||||
To make this work, for each st function call we record an identifier for the
|
||||
DG it was effectively called on (see Note [DeltaGenerator method invocation]).
|
||||
We also record the identifier for each DG returned by an st function call, if
|
||||
it returns one. Then, for each recorded message, we get the current DG instance
|
||||
corresponding to the DG the message was originally called on, and enqueue the
|
||||
message using that, recording any new DGs produced in case a later st function
|
||||
call is on one of them.
|
||||
"""
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
|
||||
# Maps originally recorded dg ids to this script run's version of that dg
|
||||
returned_dgs: dict[str, DeltaGenerator] = {
|
||||
result.main_id: st._main,
|
||||
result.sidebar_id: st.sidebar,
|
||||
}
|
||||
try:
|
||||
for msg in result.messages:
|
||||
if isinstance(msg, ElementMsgData):
|
||||
if msg.media_data is not None:
|
||||
for data in msg.media_data:
|
||||
runtime.get_instance().media_file_mgr.add(
|
||||
data.media, data.mimetype, data.media_id
|
||||
)
|
||||
dg = returned_dgs[msg.id_of_dg_called_on]
|
||||
maybe_dg = dg._enqueue(msg.delta_type, msg.message)
|
||||
if isinstance(maybe_dg, DeltaGenerator):
|
||||
returned_dgs[msg.returned_dgs_id] = maybe_dg
|
||||
elif isinstance(msg, BlockMsgData):
|
||||
dg = returned_dgs[msg.id_of_dg_called_on]
|
||||
new_dg = dg._block(msg.message)
|
||||
returned_dgs[msg.returned_dgs_id] = new_dg
|
||||
except KeyError as ex:
|
||||
raise CacheReplayClosureError(cache_type, cached_func) from ex
|
||||
|
||||
|
||||
def show_widget_replay_deprecation(
|
||||
decorator: Literal["cache_data", "cache_resource"],
|
||||
) -> None:
|
||||
show_deprecation_warning(
|
||||
"The cached widget replay feature was removed in 1.38. The "
|
||||
"`experimental_allow_widgets` parameter will also be removed "
|
||||
"in a future release. Please remove the `experimental_allow_widgets` parameter "
|
||||
f"from the `@st.{decorator}` decorator and move all widget commands outside of "
|
||||
"cached functions.\n\nTo speed up your app, we recommend moving your widgets "
|
||||
"into fragments. Find out more about fragments in "
|
||||
"[our docs](https://docs.streamlit.io/develop/api-reference/execution-flow/st.fragment). "
|
||||
"\n\nIf you have a specific use-case that requires the "
|
||||
"`experimental_allow_widgets` functionality, please tell us via an "
|
||||
"[issue on Github](https://github.com/streamlit/streamlit/issues)."
|
||||
)
|
||||
@@ -0,0 +1,637 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Hashing for st.cache_data and st.cache_resource."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import dataclasses
|
||||
import datetime
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
import weakref
|
||||
from enum import Enum
|
||||
from re import Pattern
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Callable, Final, Union, cast
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from streamlit import logger, type_util, util
|
||||
from streamlit.errors import StreamlitAPIException
|
||||
from streamlit.runtime.caching.cache_errors import UnhashableTypeError
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
||||
|
||||
_LOGGER: Final = logger.get_logger(__name__)
|
||||
|
||||
# If a dataframe has more than this many rows, we consider it large and hash a sample.
|
||||
_PANDAS_ROWS_LARGE: Final = 50_000
|
||||
_PANDAS_SAMPLE_SIZE: Final = 10_000
|
||||
|
||||
# Similar to dataframes, we also sample large numpy arrays.
|
||||
_NP_SIZE_LARGE: Final = 500_000
|
||||
_NP_SAMPLE_SIZE: Final = 100_000
|
||||
|
||||
HashFuncsDict: TypeAlias = dict[Union[str, type[Any]], Callable[[Any], Any]]
|
||||
|
||||
# Arbitrary item to denote where we found a cycle in a hashed object.
|
||||
# This allows us to hash self-referencing lists, dictionaries, etc.
|
||||
_CYCLE_PLACEHOLDER: Final = (
|
||||
b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE"
|
||||
)
|
||||
|
||||
|
||||
class UserHashError(StreamlitAPIException):
|
||||
def __init__(
|
||||
self,
|
||||
orig_exc,
|
||||
object_to_hash,
|
||||
hash_func,
|
||||
cache_type: CacheType | None = None,
|
||||
):
|
||||
self.alternate_name = type(orig_exc).__name__
|
||||
self.hash_func = hash_func
|
||||
self.cache_type = cache_type
|
||||
|
||||
msg = self._get_message_from_func(orig_exc, object_to_hash)
|
||||
|
||||
super().__init__(msg)
|
||||
self.with_traceback(orig_exc.__traceback__)
|
||||
|
||||
def _get_message_from_func(self, orig_exc, cached_func):
|
||||
args = self._get_error_message_args(orig_exc, cached_func)
|
||||
|
||||
return (
|
||||
"""
|
||||
%(orig_exception_desc)s
|
||||
|
||||
This error is likely due to a bug in %(hash_func_name)s, which is a
|
||||
user-defined hash function that was passed into the `%(cache_primitive)s` decorator of
|
||||
%(object_desc)s.
|
||||
|
||||
%(hash_func_name)s failed when hashing an object of type
|
||||
`%(failed_obj_type_str)s`. If you don't know where that object is coming from,
|
||||
try looking at the hash chain below for an object that you do recognize, then
|
||||
pass that to `hash_funcs` instead:
|
||||
|
||||
```
|
||||
%(hash_stack)s
|
||||
```
|
||||
|
||||
If you think this is actually a Streamlit bug, please
|
||||
[file a bug report here](https://github.com/streamlit/streamlit/issues/new/choose).
|
||||
"""
|
||||
% args
|
||||
).strip("\n")
|
||||
|
||||
def _get_error_message_args(
|
||||
self,
|
||||
orig_exc: BaseException,
|
||||
failed_obj: Any,
|
||||
) -> dict[str, Any]:
|
||||
hash_source = hash_stacks.current.hash_source
|
||||
|
||||
failed_obj_type_str = type_util.get_fqn_type(failed_obj)
|
||||
|
||||
if hash_source is None:
|
||||
object_desc = "something"
|
||||
else:
|
||||
if hasattr(hash_source, "__name__"):
|
||||
object_desc = f"`{hash_source.__name__}()`"
|
||||
else:
|
||||
object_desc = "a function"
|
||||
|
||||
decorator_name = ""
|
||||
if self.cache_type is CacheType.RESOURCE:
|
||||
decorator_name = "@st.cache_resource"
|
||||
elif self.cache_type is CacheType.DATA:
|
||||
decorator_name = "@st.cache_data"
|
||||
|
||||
if hasattr(self.hash_func, "__name__"):
|
||||
hash_func_name = f"`{self.hash_func.__name__}()`"
|
||||
else:
|
||||
hash_func_name = "a function"
|
||||
|
||||
return {
|
||||
"orig_exception_desc": str(orig_exc),
|
||||
"failed_obj_type_str": failed_obj_type_str,
|
||||
"hash_stack": hash_stacks.current.pretty_print(),
|
||||
"object_desc": object_desc,
|
||||
"cache_primitive": decorator_name,
|
||||
"hash_func_name": hash_func_name,
|
||||
}
|
||||
|
||||
|
||||
def update_hash(
|
||||
val: Any,
|
||||
hasher,
|
||||
cache_type: CacheType,
|
||||
hash_source: Callable[..., Any] | None = None,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
) -> None:
|
||||
"""Updates a hashlib hasher with the hash of val.
|
||||
|
||||
This is the main entrypoint to hashing.py.
|
||||
"""
|
||||
|
||||
hash_stacks.current.hash_source = hash_source
|
||||
|
||||
ch = _CacheFuncHasher(cache_type, hash_funcs)
|
||||
ch.update(hasher, val)
|
||||
|
||||
|
||||
class _HashStack:
|
||||
"""Stack of what has been hashed, for debug and circular reference detection.
|
||||
|
||||
This internally keeps 1 stack per thread.
|
||||
|
||||
Internally, this stores the ID of pushed objects rather than the objects
|
||||
themselves because otherwise the "in" operator inside __contains__ would
|
||||
fail for objects that don't return a boolean for "==" operator. For
|
||||
example, arr == 10 where arr is a NumPy array returns another NumPy array.
|
||||
This causes the "in" to crash since it expects a boolean.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._stack: collections.OrderedDict[int, list[Any]] = collections.OrderedDict()
|
||||
# A function that we decorate with streamlit cache
|
||||
# primitive (st.cache_data or st.cache_resource).
|
||||
self.hash_source: Callable[..., Any] | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def push(self, val: Any):
|
||||
self._stack[id(val)] = val
|
||||
|
||||
def pop(self):
|
||||
self._stack.popitem()
|
||||
|
||||
def __contains__(self, val: Any):
|
||||
return id(val) in self._stack
|
||||
|
||||
def pretty_print(self) -> str:
|
||||
def to_str(v: Any) -> str:
|
||||
try:
|
||||
return f"Object of type {type_util.get_fqn_type(v)}: {str(v)}"
|
||||
except Exception:
|
||||
return "<Unable to convert item to string>"
|
||||
|
||||
return "\n".join(to_str(x) for x in reversed(self._stack.values()))
|
||||
|
||||
|
||||
class _HashStacks:
|
||||
"""Stacks of what has been hashed, with at most 1 stack per thread."""
|
||||
|
||||
def __init__(self):
|
||||
self._stacks: weakref.WeakKeyDictionary[threading.Thread, _HashStack] = (
|
||||
weakref.WeakKeyDictionary()
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
@property
|
||||
def current(self) -> _HashStack:
|
||||
current_thread = threading.current_thread()
|
||||
|
||||
stack = self._stacks.get(current_thread, None)
|
||||
|
||||
if stack is None:
|
||||
stack = _HashStack()
|
||||
self._stacks[current_thread] = stack
|
||||
|
||||
return stack
|
||||
|
||||
|
||||
hash_stacks = _HashStacks()
|
||||
|
||||
|
||||
def _int_to_bytes(i: int) -> bytes:
|
||||
num_bytes = (i.bit_length() + 8) // 8
|
||||
return i.to_bytes(num_bytes, "little", signed=True)
|
||||
|
||||
|
||||
def _float_to_bytes(f: float) -> bytes:
|
||||
# Lazy-load for performance reasons.
|
||||
import struct
|
||||
|
||||
# Floats are 64bit in Python, so we need to use the "d" format.
|
||||
return struct.pack("<d", f)
|
||||
|
||||
|
||||
def _key(obj: Any | None) -> Any:
|
||||
"""Return key for memoization."""
|
||||
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
def is_simple(obj):
|
||||
return (
|
||||
isinstance(obj, bytes)
|
||||
or isinstance(obj, bytearray)
|
||||
or isinstance(obj, str)
|
||||
or isinstance(obj, float)
|
||||
or isinstance(obj, int)
|
||||
or isinstance(obj, bool)
|
||||
or isinstance(obj, uuid.UUID)
|
||||
or obj is None
|
||||
)
|
||||
|
||||
if is_simple(obj):
|
||||
return obj
|
||||
|
||||
if isinstance(obj, tuple):
|
||||
if all(map(is_simple, obj)):
|
||||
return obj
|
||||
|
||||
if isinstance(obj, list):
|
||||
if all(map(is_simple, obj)):
|
||||
return ("__l", tuple(obj))
|
||||
|
||||
if inspect.isbuiltin(obj) or inspect.isroutine(obj) or inspect.iscode(obj):
|
||||
return id(obj)
|
||||
|
||||
return NoResult
|
||||
|
||||
|
||||
class _CacheFuncHasher:
|
||||
"""A hasher that can hash objects with cycles."""
|
||||
|
||||
def __init__(self, cache_type: CacheType, hash_funcs: HashFuncsDict | None = None):
|
||||
# Can't use types as the keys in the internal _hash_funcs because
|
||||
# we always remove user-written modules from memory when rerunning a
|
||||
# script in order to reload it and grab the latest code changes.
|
||||
# (See LocalSourcesWatcher.py:on_file_changed) This causes
|
||||
# the type object to refer to different underlying class instances each run,
|
||||
# so type-based comparisons fail. To solve this, we use the types converted
|
||||
# to fully-qualified strings as keys in our internal dict.
|
||||
self._hash_funcs: HashFuncsDict
|
||||
if hash_funcs:
|
||||
self._hash_funcs = {
|
||||
k if isinstance(k, str) else type_util.get_fqn(k): v
|
||||
for k, v in hash_funcs.items()
|
||||
}
|
||||
else:
|
||||
self._hash_funcs = {}
|
||||
self._hashes: dict[Any, bytes] = {}
|
||||
|
||||
# The number of the bytes in the hash.
|
||||
self.size = 0
|
||||
|
||||
self.cache_type = cache_type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def to_bytes(self, obj: Any) -> bytes:
|
||||
"""Add memoization to _to_bytes and protect against cycles in data structures."""
|
||||
tname = type(obj).__qualname__.encode()
|
||||
key = (tname, _key(obj))
|
||||
|
||||
# Memoize if possible.
|
||||
if key[1] is not NoResult:
|
||||
if key in self._hashes:
|
||||
return self._hashes[key]
|
||||
|
||||
# Break recursive cycles.
|
||||
if obj in hash_stacks.current:
|
||||
return _CYCLE_PLACEHOLDER
|
||||
|
||||
hash_stacks.current.push(obj)
|
||||
|
||||
try:
|
||||
# Hash the input
|
||||
b = b"%s:%s" % (tname, self._to_bytes(obj))
|
||||
|
||||
# Hmmm... It's possible that the size calculation is wrong. When we
|
||||
# call to_bytes inside _to_bytes things get double-counted.
|
||||
self.size += sys.getsizeof(b)
|
||||
|
||||
if key[1] is not NoResult:
|
||||
self._hashes[key] = b
|
||||
|
||||
finally:
|
||||
# In case an UnhashableTypeError (or other) error is thrown, clean up the
|
||||
# stack so we don't get false positives in future hashing calls
|
||||
hash_stacks.current.pop()
|
||||
|
||||
return b
|
||||
|
||||
def update(self, hasher, obj: Any) -> None:
|
||||
"""Update the provided hasher with the hash of an object."""
|
||||
b = self.to_bytes(obj)
|
||||
hasher.update(b)
|
||||
|
||||
def _to_bytes(self, obj: Any) -> bytes:
|
||||
"""Hash objects to bytes, including code with dependencies.
|
||||
|
||||
Python's built in `hash` does not produce consistent results across
|
||||
runs.
|
||||
"""
|
||||
|
||||
h = hashlib.new("md5", usedforsecurity=False)
|
||||
|
||||
if type_util.is_type(obj, "unittest.mock.Mock") or type_util.is_type(
|
||||
obj, "unittest.mock.MagicMock"
|
||||
):
|
||||
# Mock objects can appear to be infinitely
|
||||
# deep, so we don't try to hash them at all.
|
||||
return self.to_bytes(id(obj))
|
||||
|
||||
elif isinstance(obj, bytes) or isinstance(obj, bytearray):
|
||||
return obj
|
||||
|
||||
elif type_util.get_fqn_type(obj) in self._hash_funcs:
|
||||
# Escape hatch for unsupported objects
|
||||
hash_func = self._hash_funcs[type_util.get_fqn_type(obj)]
|
||||
try:
|
||||
output = hash_func(obj)
|
||||
except Exception as ex:
|
||||
raise UserHashError(
|
||||
ex, obj, hash_func=hash_func, cache_type=self.cache_type
|
||||
) from ex
|
||||
return self.to_bytes(output)
|
||||
|
||||
elif isinstance(obj, str):
|
||||
return obj.encode()
|
||||
|
||||
elif isinstance(obj, float):
|
||||
return _float_to_bytes(obj)
|
||||
|
||||
elif isinstance(obj, int):
|
||||
return _int_to_bytes(obj)
|
||||
|
||||
elif isinstance(obj, uuid.UUID):
|
||||
return obj.bytes
|
||||
|
||||
elif isinstance(obj, datetime.datetime):
|
||||
return obj.isoformat().encode()
|
||||
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
self.update(h, item)
|
||||
return h.digest()
|
||||
|
||||
elif isinstance(obj, dict):
|
||||
for item in obj.items():
|
||||
self.update(h, item)
|
||||
return h.digest()
|
||||
|
||||
elif obj is None:
|
||||
return b"0"
|
||||
|
||||
elif obj is True:
|
||||
return b"1"
|
||||
|
||||
elif obj is False:
|
||||
return b"0"
|
||||
|
||||
elif not isinstance(obj, type) and dataclasses.is_dataclass(obj):
|
||||
return self.to_bytes(dataclasses.asdict(obj))
|
||||
elif isinstance(obj, Enum):
|
||||
return str(obj).encode()
|
||||
|
||||
elif type_util.is_type(obj, "pandas.core.series.Series"):
|
||||
import pandas as pd
|
||||
|
||||
obj = cast("pd.Series", obj)
|
||||
self.update(h, obj.size)
|
||||
self.update(h, obj.dtype.name)
|
||||
|
||||
if len(obj) >= _PANDAS_ROWS_LARGE:
|
||||
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
|
||||
|
||||
try:
|
||||
self.update(h, pd.util.hash_pandas_object(obj).to_numpy().tobytes())
|
||||
return h.digest()
|
||||
except TypeError:
|
||||
_LOGGER.warning(
|
||||
"Pandas Series hash failed. Falling back to pickling the object.",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Use pickle if pandas cannot hash the object for example if
|
||||
# it contains unhashable objects.
|
||||
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
elif type_util.is_type(obj, "pandas.core.frame.DataFrame"):
|
||||
import pandas as pd
|
||||
|
||||
obj = cast("pd.DataFrame", obj)
|
||||
self.update(h, obj.shape)
|
||||
|
||||
if len(obj) >= _PANDAS_ROWS_LARGE:
|
||||
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
|
||||
try:
|
||||
column_hash_bytes = self.to_bytes(
|
||||
pd.util.hash_pandas_object(obj.dtypes)
|
||||
)
|
||||
self.update(h, column_hash_bytes)
|
||||
values_hash_bytes = self.to_bytes(pd.util.hash_pandas_object(obj))
|
||||
self.update(h, values_hash_bytes)
|
||||
return h.digest()
|
||||
except TypeError:
|
||||
_LOGGER.warning(
|
||||
"Pandas DataFrame hash failed. Falling back to pickling the object.",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Use pickle if pandas cannot hash the object for example if
|
||||
# it contains unhashable objects.
|
||||
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
elif type_util.is_type(obj, "polars.series.series.Series"):
|
||||
import polars as pl # type: ignore[import-not-found]
|
||||
|
||||
obj = cast("pl.Series", obj)
|
||||
self.update(h, str(obj.dtype).encode())
|
||||
self.update(h, obj.shape)
|
||||
|
||||
if len(obj) >= _PANDAS_ROWS_LARGE:
|
||||
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, seed=0)
|
||||
|
||||
try:
|
||||
self.update(h, obj.hash(seed=0).to_arrow().to_string().encode())
|
||||
return h.digest()
|
||||
except TypeError:
|
||||
_LOGGER.warning(
|
||||
"Polars Series hash failed. Falling back to pickling the object.",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Use pickle if polars cannot hash the object for example if
|
||||
# it contains unhashable objects.
|
||||
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
|
||||
elif type_util.is_type(obj, "polars.dataframe.frame.DataFrame"):
|
||||
import polars as pl # noqa: TC002
|
||||
|
||||
obj = cast("pl.DataFrame", obj)
|
||||
self.update(h, obj.shape)
|
||||
|
||||
if len(obj) >= _PANDAS_ROWS_LARGE:
|
||||
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, seed=0)
|
||||
try:
|
||||
for c, t in obj.schema.items():
|
||||
self.update(h, c.encode())
|
||||
self.update(h, str(t).encode())
|
||||
|
||||
values_hash_bytes = (
|
||||
obj.hash_rows(seed=0).hash(seed=0).to_arrow().to_string().encode()
|
||||
)
|
||||
|
||||
self.update(h, values_hash_bytes)
|
||||
return h.digest()
|
||||
except TypeError:
|
||||
_LOGGER.warning(
|
||||
"Polars DataFrame hash failed. Falling back to pickling the object.",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Use pickle if polars cannot hash the object for example if
|
||||
# it contains unhashable objects.
|
||||
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
|
||||
elif type_util.is_type(obj, "numpy.ndarray"):
|
||||
import numpy as np
|
||||
|
||||
# write cast type as string to make it work with our Python 3.8 tests
|
||||
# - can be removed once we sunset support for Python 3.8
|
||||
obj = cast("np.ndarray[Any, Any]", obj)
|
||||
self.update(h, obj.shape)
|
||||
self.update(h, str(obj.dtype))
|
||||
|
||||
if obj.size >= _NP_SIZE_LARGE:
|
||||
import numpy as np
|
||||
|
||||
state = np.random.RandomState(0)
|
||||
obj = state.choice(obj.flat, size=_NP_SAMPLE_SIZE)
|
||||
|
||||
self.update(h, obj.tobytes())
|
||||
return h.digest()
|
||||
elif type_util.is_type(obj, "PIL.Image.Image"):
|
||||
import numpy as np
|
||||
from PIL.Image import Image # noqa: TC002
|
||||
|
||||
obj = cast("Image", obj)
|
||||
|
||||
# we don't just hash the results of obj.tobytes() because we want to use
|
||||
# the sampling logic for numpy data
|
||||
np_array = np.frombuffer(obj.tobytes(), dtype="uint8")
|
||||
return self.to_bytes(np_array)
|
||||
|
||||
elif inspect.isbuiltin(obj):
|
||||
return bytes(obj.__name__.encode())
|
||||
|
||||
elif isinstance(obj, MappingProxyType) or isinstance(
|
||||
obj, collections.abc.ItemsView
|
||||
):
|
||||
return self.to_bytes(dict(obj))
|
||||
|
||||
elif type_util.is_type(obj, "builtins.getset_descriptor"):
|
||||
return bytes(obj.__qualname__.encode())
|
||||
|
||||
elif isinstance(obj, UploadedFile):
|
||||
# UploadedFile is a BytesIO (thus IOBase) but has a name.
|
||||
# It does not have a timestamp so this must come before
|
||||
# temporary files
|
||||
self.update(h, obj.name)
|
||||
self.update(h, obj.tell())
|
||||
self.update(h, obj.getvalue())
|
||||
return h.digest()
|
||||
|
||||
elif hasattr(obj, "name") and (
|
||||
isinstance(obj, io.IOBase)
|
||||
# Handle temporary files used during testing
|
||||
or isinstance(obj, tempfile._TemporaryFileWrapper)
|
||||
):
|
||||
# Hash files as name + last modification date + offset.
|
||||
# NB: we're using hasattr("name") to differentiate between
|
||||
# on-disk and in-memory StringIO/BytesIO file representations.
|
||||
# That means that this condition must come *before* the next
|
||||
# condition, which just checks for StringIO/BytesIO.
|
||||
obj_name = getattr(obj, "name", "wonthappen") # Just to appease MyPy.
|
||||
self.update(h, obj_name)
|
||||
self.update(h, os.path.getmtime(obj_name))
|
||||
self.update(h, obj.tell())
|
||||
return h.digest()
|
||||
|
||||
elif isinstance(obj, Pattern):
|
||||
return self.to_bytes([obj.pattern, obj.flags])
|
||||
|
||||
elif isinstance(obj, io.StringIO) or isinstance(obj, io.BytesIO):
|
||||
# Hash in-memory StringIO/BytesIO by their full contents
|
||||
# and seek position.
|
||||
self.update(h, obj.tell())
|
||||
self.update(h, obj.getvalue())
|
||||
return h.digest()
|
||||
|
||||
elif type_util.is_type(obj, "numpy.ufunc"):
|
||||
# For numpy.remainder, this returns remainder.
|
||||
return bytes(obj.__name__.encode())
|
||||
|
||||
elif inspect.ismodule(obj):
|
||||
# TODO: Figure out how to best show this kind of warning to the
|
||||
# user. In the meantime, show nothing. This scenario is too common,
|
||||
# so the current warning is quite annoying...
|
||||
# st.warning(('Streamlit does not support hashing modules. '
|
||||
# 'We did not hash `%s`.') % obj.__name__)
|
||||
# TODO: Hash more than just the name for internal modules.
|
||||
return self.to_bytes(obj.__name__)
|
||||
|
||||
elif inspect.isclass(obj):
|
||||
# TODO: Figure out how to best show this kind of warning to the
|
||||
# user. In the meantime, show nothing. This scenario is too common,
|
||||
# (e.g. in every "except" statement) so the current warning is
|
||||
# quite annoying...
|
||||
# st.warning(('Streamlit does not support hashing classes. '
|
||||
# 'We did not hash `%s`.') % obj.__name__)
|
||||
# TODO: Hash more than just the name of classes.
|
||||
return self.to_bytes(obj.__name__)
|
||||
|
||||
elif isinstance(obj, functools.partial):
|
||||
# The return value of functools.partial is not a plain function:
|
||||
# it's a callable object that remembers the original function plus
|
||||
# the values you pickled into it. So here we need to special-case it.
|
||||
self.update(h, obj.args)
|
||||
self.update(h, obj.func)
|
||||
self.update(h, obj.keywords)
|
||||
return h.digest()
|
||||
|
||||
else:
|
||||
# As a last resort, hash the output of the object's __reduce__ method
|
||||
try:
|
||||
reduce_data = obj.__reduce__()
|
||||
except Exception as ex:
|
||||
raise UnhashableTypeError() from ex
|
||||
|
||||
for item in reduce_data:
|
||||
self.update(h, item)
|
||||
return h.digest()
|
||||
|
||||
|
||||
class NoResult:
|
||||
"""Placeholder class for return values when None is meaningful."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A library of caching utilities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
||||
|
||||
from streamlit import deprecation_util
|
||||
from streamlit.runtime.caching import CACHE_DOCS_URL
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.caching.hashing import HashFuncsDict
|
||||
|
||||
# Type-annotate the decorator function.
|
||||
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@gather_metrics("cache")
|
||||
def cache(
|
||||
func: F | None = None,
|
||||
persist: bool = False,
|
||||
allow_output_mutation: bool = False,
|
||||
show_spinner: bool = True,
|
||||
suppress_st_warning: bool = False,
|
||||
hash_funcs: HashFuncsDict | None = None,
|
||||
max_entries: int | None = None,
|
||||
ttl: float | None = None,
|
||||
):
|
||||
"""Legacy caching decorator (deprecated).
|
||||
|
||||
Legacy caching with ``st.cache`` has been removed from Streamlit. This is
|
||||
now an alias for ``st.cache_data`` and ``st.cache_resource``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
The function to cache. Streamlit hashes the function's source code.
|
||||
|
||||
persist : bool
|
||||
Whether to persist the cache on disk.
|
||||
|
||||
allow_output_mutation : bool
|
||||
Whether to use ``st.cache_data`` or ``st.cache_resource``. If this is
|
||||
``False`` (default), the arguments are passed to ``st.cache_data``. If
|
||||
this is ``True``, the arguments are passed to ``st.cache_resource``.
|
||||
|
||||
show_spinner : bool
|
||||
Enable the spinner. Default is ``True`` to show a spinner when there is
|
||||
a "cache miss" and the cached data is being created.
|
||||
|
||||
suppress_st_warning : bool
|
||||
This is not used.
|
||||
|
||||
hash_funcs : dict or None
|
||||
Mapping of types or fully qualified names to hash functions. This is used to
|
||||
override the behavior of the hasher inside Streamlit's caching mechanism: when
|
||||
the hasher encounters an object, it will first check to see if its type matches
|
||||
a key in this dict and, if so, will use the provided function to generate a hash
|
||||
for it. See below for an example of how this can be used.
|
||||
|
||||
max_entries : int or None
|
||||
The maximum number of entries to keep in the cache, or ``None``
|
||||
for an unbounded cache. (When a new entry is added to a full cache,
|
||||
the oldest cached entry will be removed.) The default is ``None``.
|
||||
|
||||
ttl : float or None
|
||||
The maximum number of seconds to keep an entry in the cache, or
|
||||
None if cache entries should not expire. The default is None.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
>>>
|
||||
>>> d1 = fetch_and_clean_data(DATA_URL_1)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> d2 = fetch_and_clean_data(DATA_URL_1)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value. This means that now the data in d1 is the same as in d2.
|
||||
>>>
|
||||
>>> d3 = fetch_and_clean_data(DATA_URL_2)
|
||||
>>> # This is a different URL, so the function executes.
|
||||
|
||||
To set the ``persist`` parameter, use this command as follows:
|
||||
|
||||
>>> @st.cache(persist=True)
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
|
||||
To disable hashing return values, set the ``allow_output_mutation`` parameter to
|
||||
``True``:
|
||||
|
||||
>>> @st.cache(allow_output_mutation=True)
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
|
||||
|
||||
To override the default hashing behavior, pass a custom hash function.
|
||||
You can do that by mapping a type (e.g. ``MongoClient``) to a hash function (``id``)
|
||||
like this:
|
||||
|
||||
>>> @st.cache(hash_funcs={MongoClient: id})
|
||||
... def connect_to_database(url):
|
||||
... return MongoClient(url)
|
||||
|
||||
Alternatively, you can map the type's fully-qualified name
|
||||
(e.g. ``"pymongo.mongo_client.MongoClient"``) to the hash function instead:
|
||||
|
||||
>>> @st.cache(hash_funcs={"pymongo.mongo_client.MongoClient": id})
|
||||
... def connect_to_database(url):
|
||||
... return MongoClient(url)
|
||||
|
||||
"""
|
||||
import streamlit as st
|
||||
|
||||
deprecation_util.show_deprecation_warning(
|
||||
f"""
|
||||
`st.cache` is deprecated and will be removed soon. Please use one of Streamlit's new
|
||||
caching commands, `st.cache_data` or `st.cache_resource`. More information
|
||||
[in our docs]({CACHE_DOCS_URL}).
|
||||
|
||||
**Note**: The behavior of `st.cache` was updated in Streamlit 1.36 to the new caching
|
||||
logic used by `st.cache_data` and `st.cache_resource`. This might lead to some problems
|
||||
or unexpected behavior in certain edge cases.
|
||||
"""
|
||||
)
|
||||
|
||||
# suppress_st_warning is unused since its not supported by the new caching commands
|
||||
|
||||
if allow_output_mutation:
|
||||
return st.cache_resource( # type: ignore
|
||||
func,
|
||||
show_spinner=show_spinner,
|
||||
hash_funcs=hash_funcs,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
)
|
||||
|
||||
return st.cache_data( # type: ignore
|
||||
func,
|
||||
persist=persist,
|
||||
show_spinner=show_spinner,
|
||||
hash_funcs=hash_funcs,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
)
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorage,
|
||||
CacheStorageContext,
|
||||
CacheStorageError,
|
||||
CacheStorageKeyNotFoundError,
|
||||
CacheStorageManager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CacheStorage",
|
||||
"CacheStorageContext",
|
||||
"CacheStorageError",
|
||||
"CacheStorageKeyNotFoundError",
|
||||
"CacheStorageManager",
|
||||
]
|
||||
@@ -0,0 +1,239 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Declares the CacheStorageContext dataclass, which contains parameter information for
|
||||
each function decorated by `@st.cache_data` (for example: ttl, max_entries etc.).
|
||||
|
||||
Declares the CacheStorageManager protocol, which implementations are used
|
||||
to create CacheStorage instances and to optionally clear all cache storages,
|
||||
that were created by this manager, and to check if the context is valid for the storage.
|
||||
|
||||
Declares the CacheStorage protocol, which implementations are used to store cached
|
||||
values for a single `@st.cache_data` decorated function serialized as bytes.
|
||||
|
||||
How these classes work together
|
||||
-------------------------------
|
||||
- CacheStorageContext : this is a dataclass that contains the parameters from
|
||||
`@st.cache_data` that are passed to the CacheStorageManager.create() method.
|
||||
|
||||
- CacheStorageManager : each instance of this is able to create CacheStorage
|
||||
instances, and optionally to clear data of all cache storages.
|
||||
|
||||
- CacheStorage : each instance of this is able to get, set, delete, and clear
|
||||
entries for a single `@st.cache_data` decorated function.
|
||||
|
||||
┌───────────────────────────────┐
|
||||
│ │
|
||||
│ CacheStorageManager │
|
||||
│ │
|
||||
│ - clear_all(optional) │
|
||||
│ - check_context │
|
||||
│ │
|
||||
└──┬────────────────────────────┘
|
||||
│
|
||||
│ ┌──────────────────────┐
|
||||
│ │ CacheStorage │
|
||||
│ create(context)│ │
|
||||
└────────────────► - get │
|
||||
│ - set │
|
||||
│ - delete │
|
||||
│ - close (optional)│
|
||||
│ - clear │
|
||||
└──────────────────────┘
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Protocol
|
||||
|
||||
|
||||
class CacheStorageError(Exception):
|
||||
"""Base exception raised by the cache storage."""
|
||||
|
||||
|
||||
class CacheStorageKeyNotFoundError(CacheStorageError):
|
||||
"""Raised when the key is not found in the cache storage."""
|
||||
|
||||
|
||||
class InvalidCacheStorageContext(CacheStorageError):
|
||||
"""Raised if the cache storage manager is not able to work with
|
||||
provided CacheStorageContext.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CacheStorageContext:
|
||||
"""Context passed to the cache storage during initialization
|
||||
This is the normalized parameters that are passed to CacheStorageManager.create()
|
||||
method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
function_key: str
|
||||
A hash computed based on function name and source code decorated
|
||||
by `@st.cache_data`
|
||||
|
||||
function_display_name: str
|
||||
The display name of the function that is decorated by `@st.cache_data`
|
||||
|
||||
ttl_seconds : float or None
|
||||
The time-to-live for the keys in storage, in seconds. If None, the entry
|
||||
will never expire.
|
||||
|
||||
max_entries : int or None
|
||||
The maximum number of entries to store in the cache storage.
|
||||
If None, the cache storage will not limit the number of entries.
|
||||
|
||||
persist : Literal["disk"] or None
|
||||
The persistence mode for the cache storage.
|
||||
Legacy parameter, that used in Streamlit current cache storage implementation.
|
||||
Could be ignored by cache storage implementation, if storage does not support
|
||||
persistence or it persistent by default.
|
||||
"""
|
||||
|
||||
function_key: str
|
||||
function_display_name: str
|
||||
ttl_seconds: float | None = None
|
||||
max_entries: int | None = None
|
||||
persist: Literal["disk"] | None = None
|
||||
|
||||
|
||||
class CacheStorage(Protocol):
|
||||
"""Cache storage protocol, that should be implemented by the concrete cache storages.
|
||||
Used to store cached values for a single `@st.cache_data` decorated function
|
||||
serialized as bytes.
|
||||
|
||||
CacheStorage instances should be created by `CacheStorageManager.create()` method.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: The methods of this protocol could be called from multiple threads.
|
||||
This is a responsibility of the concrete implementation to ensure thread safety
|
||||
guarantees.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str) -> bytes:
|
||||
"""Returns the stored value for the key.
|
||||
|
||||
Raises
|
||||
------
|
||||
CacheStorageKeyNotFoundError
|
||||
Raised if the key is not in the storage.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: bytes) -> None:
|
||||
"""Sets the value for a given key."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete a given key."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Remove all keys for the storage."""
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the cache storage, it is optional to implement, and should be used
|
||||
to close open resources, before we delete the storage instance.
|
||||
e.g. close the database connection etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class CacheStorageManager(Protocol):
|
||||
"""Cache storage manager protocol, that should be implemented by the concrete
|
||||
cache storage managers.
|
||||
|
||||
It is responsible for:
|
||||
- Creating cache storage instances for the specific
|
||||
decorated functions,
|
||||
- Validating the context for the cache storages.
|
||||
- Optionally clearing all cache storages in optimal way.
|
||||
|
||||
It should be created during Runtime initialization.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create(self, context: CacheStorageContext) -> CacheStorage:
|
||||
"""Creates a new cache storage instance
|
||||
Please note that the ttl, max_entries and other context fields are specific
|
||||
for whole storage, not for individual key.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: Should be safe to call from any thread.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""Remove everything what possible from the cache storages in optimal way.
|
||||
meaningful default behaviour is to raise NotImplementedError, so this is not
|
||||
abstractmethod.
|
||||
|
||||
The method is optional to implement: cache data API will fall back to remove
|
||||
all available storages one by one via storage.clear() method
|
||||
if clear_all raises NotImplementedError.
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError
|
||||
Raised if the storage manager does not provide an ability to clear
|
||||
all storages at once in optimal way.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: This method could be called from multiple threads.
|
||||
This is a responsibility of the concrete implementation to ensure
|
||||
thread safety guarantees.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def check_context(self, context: CacheStorageContext) -> None:
|
||||
"""Checks if the context is valid for the storage manager.
|
||||
This method should not return anything, but log message or raise an exception
|
||||
if the context is invalid.
|
||||
|
||||
In case of raising an exception, we not handle it and let the exception to be
|
||||
propagated.
|
||||
|
||||
check_context is called only once at the moment of creating `@st.cache_data`
|
||||
decorator for specific function, so it is not called for every cache hit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
context: CacheStorageContext
|
||||
The context to check for the storage manager, dummy function_key in context
|
||||
will be used, since it is not computed at the point of calling this method.
|
||||
|
||||
Raises
|
||||
------
|
||||
InvalidCacheStorageContext
|
||||
Raised if the cache storage manager is not able to work with provided
|
||||
CacheStorageContext. When possible we should log message instead, since
|
||||
this exception will be propagated to the user.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: Should be safe to call from any thread.
|
||||
"""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorage,
|
||||
CacheStorageContext,
|
||||
CacheStorageKeyNotFoundError,
|
||||
CacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
|
||||
InMemoryCacheStorageWrapper,
|
||||
)
|
||||
|
||||
|
||||
class MemoryCacheStorageManager(CacheStorageManager):
|
||||
def create(self, context: CacheStorageContext) -> CacheStorage:
|
||||
"""Creates a new cache storage instance wrapped with in-memory cache layer."""
|
||||
persist_storage = DummyCacheStorage()
|
||||
return InMemoryCacheStorageWrapper(
|
||||
persist_storage=persist_storage, context=context
|
||||
)
|
||||
|
||||
def clear_all(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def check_context(self, context: CacheStorageContext) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DummyCacheStorage(CacheStorage):
|
||||
def get(self, key: str) -> bytes:
|
||||
"""
|
||||
Dummy gets the value for a given key,
|
||||
always raises an CacheStorageKeyNotFoundError.
|
||||
"""
|
||||
raise CacheStorageKeyNotFoundError("Key not found in dummy cache")
|
||||
|
||||
def set(self, key: str, value: bytes) -> None:
|
||||
pass
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
@@ -0,0 +1,145 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import threading
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching import cache_utils
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorage,
|
||||
CacheStorageContext,
|
||||
CacheStorageKeyNotFoundError,
|
||||
)
|
||||
from streamlit.runtime.stats import CacheStat
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
class InMemoryCacheStorageWrapper(CacheStorage):
|
||||
"""
|
||||
In-memory cache storage wrapper.
|
||||
|
||||
This class wraps a cache storage and adds an in-memory cache front layer,
|
||||
which is used to reduce the number of calls to the storage.
|
||||
|
||||
The in-memory cache is a TTL cache, which means that the entries are
|
||||
automatically removed if a given time to live (TTL) has passed.
|
||||
|
||||
The in-memory cache is also an LRU cache, which means that the entries
|
||||
are automatically removed if the cache size exceeds a given maxsize.
|
||||
|
||||
If the storage implements its strategy for maxsize, it is recommended
|
||||
(but not necessary) that the storage implement the same LRU strategy,
|
||||
otherwise a situation may arise when different items are deleted from
|
||||
the memory cache and from the storage.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: in-memory caching layer is thread safe: we hold self._mem_cache_lock for
|
||||
working with this self._mem_cache object.
|
||||
However, we do not hold this lock when calling into the underlying storage,
|
||||
so it is the responsibility of the that storage to ensure that it is safe to use
|
||||
it from multiple threads.
|
||||
"""
|
||||
|
||||
def __init__(self, persist_storage: CacheStorage, context: CacheStorageContext):
|
||||
self.function_key = context.function_key
|
||||
self.function_display_name = context.function_display_name
|
||||
self._ttl_seconds = context.ttl_seconds
|
||||
self._max_entries = context.max_entries
|
||||
self._mem_cache: TTLCache[str, bytes] = TTLCache(
|
||||
maxsize=self.max_entries,
|
||||
ttl=self.ttl_seconds,
|
||||
timer=cache_utils.TTLCACHE_TIMER,
|
||||
)
|
||||
self._mem_cache_lock = threading.Lock()
|
||||
self._persist_storage = persist_storage
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> float:
|
||||
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
|
||||
|
||||
@property
|
||||
def max_entries(self) -> float:
|
||||
return float(self._max_entries) if self._max_entries is not None else math.inf
|
||||
|
||||
def get(self, key: str) -> bytes:
|
||||
"""
|
||||
Returns the stored value for the key or raise CacheStorageKeyNotFoundError if
|
||||
the key is not found.
|
||||
"""
|
||||
try:
|
||||
entry_bytes = self._read_from_mem_cache(key)
|
||||
except CacheStorageKeyNotFoundError:
|
||||
entry_bytes = self._persist_storage.get(key)
|
||||
self._write_to_mem_cache(key, entry_bytes)
|
||||
return entry_bytes
|
||||
|
||||
def set(self, key: str, value: bytes) -> None:
|
||||
"""Sets the value for a given key."""
|
||||
self._write_to_mem_cache(key, value)
|
||||
self._persist_storage.set(key, value)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete a given key."""
|
||||
self._remove_from_mem_cache(key)
|
||||
self._persist_storage.delete(key)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Delete all keys for the in memory cache, and also the persistent storage."""
|
||||
with self._mem_cache_lock:
|
||||
self._mem_cache.clear()
|
||||
self._persist_storage.clear()
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
"""Returns a list of stats in bytes for the cache memory storage per item."""
|
||||
stats = []
|
||||
|
||||
with self._mem_cache_lock:
|
||||
for item in self._mem_cache.values():
|
||||
stats.append(
|
||||
CacheStat(
|
||||
category_name="st_cache_data",
|
||||
cache_name=self.function_display_name,
|
||||
byte_length=len(item),
|
||||
)
|
||||
)
|
||||
return stats
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the cache storage."""
|
||||
self._persist_storage.close()
|
||||
|
||||
def _read_from_mem_cache(self, key: str) -> bytes:
|
||||
with self._mem_cache_lock:
|
||||
if key in self._mem_cache:
|
||||
entry = bytes(self._mem_cache[key])
|
||||
_LOGGER.debug("Memory cache HIT: %s", key)
|
||||
return entry
|
||||
|
||||
else:
|
||||
_LOGGER.debug("Memory cache MISS: %s", key)
|
||||
raise CacheStorageKeyNotFoundError("Key not found in mem cache")
|
||||
|
||||
def _write_to_mem_cache(self, key: str, entry_bytes: bytes) -> None:
|
||||
with self._mem_cache_lock:
|
||||
self._mem_cache[key] = entry_bytes
|
||||
|
||||
def _remove_from_mem_cache(self, key: str) -> None:
|
||||
with self._mem_cache_lock:
|
||||
self._mem_cache.pop(key, None)
|
||||
@@ -0,0 +1,223 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Declares the LocalDiskCacheStorageManager class, which is used
|
||||
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper,
|
||||
InMemoryCacheStorageWrapper wrapper allows to have first layer of in-memory cache,
|
||||
before accessing to LocalDiskCacheStorage itself.
|
||||
|
||||
Declares the LocalDiskCacheStorage class, which is used to store cached
|
||||
values on disk.
|
||||
|
||||
How these classes work together
|
||||
-------------------------------
|
||||
|
||||
- LocalDiskCacheStorageManager : each instance of this is able
|
||||
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper,
|
||||
and to clear data from cache storage folder. It is also LocalDiskCacheStorageManager
|
||||
responsibility to check if the context is valid for the storage, and to log warning
|
||||
if the context is not valid.
|
||||
|
||||
- LocalDiskCacheStorage : each instance of this is able to get, set, delete, and clear
|
||||
entries from disk for a single `@st.cache_data` decorated function if `persist="disk"`
|
||||
is used in CacheStorageContext.
|
||||
|
||||
|
||||
┌───────────────────────────────┐
|
||||
│ LocalDiskCacheStorageManager │
|
||||
│ │
|
||||
│ - clear_all │
|
||||
│ - check_context │
|
||||
│ │
|
||||
└──┬────────────────────────────┘
|
||||
│
|
||||
│ ┌──────────────────────────────┐
|
||||
│ │ │
|
||||
│ create(context)│ InMemoryCacheStorageWrapper │
|
||||
└────────────────► │
|
||||
│ ┌─────────────────────┐ │
|
||||
│ │ │ │
|
||||
│ │ LocalDiskStorage │ │
|
||||
│ │ │ │
|
||||
│ └─────────────────────┘ │
|
||||
│ │
|
||||
└──────────────────────────────┘
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from typing import Final
|
||||
|
||||
from streamlit import errors
|
||||
from streamlit.file_util import get_streamlit_file_path, streamlit_read, streamlit_write
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorage,
|
||||
CacheStorageContext,
|
||||
CacheStorageError,
|
||||
CacheStorageKeyNotFoundError,
|
||||
CacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
|
||||
InMemoryCacheStorageWrapper,
|
||||
)
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
# Streamlit directory where persisted @st.cache_data objects live.
|
||||
# (This is the same directory that @st.cache persisted objects live.
|
||||
# But @st.cache_data uses a different extension, so they don't overlap.)
|
||||
_CACHE_DIR_NAME: Final = "cache"
|
||||
|
||||
# The extension for our persisted @st.cache_data objects.
|
||||
# (`@st.cache_data` was originally called `@st.memo`)
|
||||
_CACHED_FILE_EXTENSION: Final = "memo"
|
||||
|
||||
|
||||
class LocalDiskCacheStorageManager(CacheStorageManager):
|
||||
def create(self, context: CacheStorageContext) -> CacheStorage:
|
||||
"""Creates a new cache storage instance wrapped with in-memory cache layer."""
|
||||
persist_storage = LocalDiskCacheStorage(context)
|
||||
return InMemoryCacheStorageWrapper(
|
||||
persist_storage=persist_storage, context=context
|
||||
)
|
||||
|
||||
def clear_all(self) -> None:
|
||||
cache_path = get_cache_folder_path()
|
||||
if os.path.isdir(cache_path):
|
||||
shutil.rmtree(cache_path)
|
||||
|
||||
def check_context(self, context: CacheStorageContext) -> None:
|
||||
if (
|
||||
context.persist == "disk"
|
||||
and context.ttl_seconds is not None
|
||||
and not math.isinf(context.ttl_seconds)
|
||||
):
|
||||
_LOGGER.warning(
|
||||
f"The cached function '{context.function_display_name}' has a TTL "
|
||||
"that will be ignored. Persistent cached functions currently don't "
|
||||
"support TTL."
|
||||
)
|
||||
|
||||
|
||||
class LocalDiskCacheStorage(CacheStorage):
|
||||
"""Cache storage that persists data to disk
|
||||
This is the default cache persistence layer for `@st.cache_data`.
|
||||
"""
|
||||
|
||||
def __init__(self, context: CacheStorageContext):
|
||||
self.function_key = context.function_key
|
||||
self.persist = context.persist
|
||||
self._ttl_seconds = context.ttl_seconds
|
||||
self._max_entries = context.max_entries
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> float:
|
||||
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
|
||||
|
||||
@property
|
||||
def max_entries(self) -> float:
|
||||
return float(self._max_entries) if self._max_entries is not None else math.inf
|
||||
|
||||
def get(self, key: str) -> bytes:
|
||||
"""
|
||||
Returns the stored value for the key if persisted,
|
||||
raise CacheStorageKeyNotFoundError if not found, or not configured
|
||||
with persist="disk".
|
||||
"""
|
||||
if self.persist == "disk":
|
||||
path = self._get_cache_file_path(key)
|
||||
try:
|
||||
with streamlit_read(path, binary=True) as input:
|
||||
value = input.read()
|
||||
_LOGGER.debug("Disk cache HIT: %s", key)
|
||||
return bytes(value)
|
||||
except FileNotFoundError:
|
||||
raise CacheStorageKeyNotFoundError("Key not found in disk cache")
|
||||
except Exception as ex:
|
||||
_LOGGER.exception("Error reading from cache")
|
||||
raise CacheStorageError("Unable to read from cache") from ex
|
||||
else:
|
||||
raise CacheStorageKeyNotFoundError(
|
||||
f"Local disk cache storage is disabled (persist={self.persist})"
|
||||
)
|
||||
|
||||
def set(self, key: str, value: bytes) -> None:
|
||||
"""Sets the value for a given key."""
|
||||
if self.persist == "disk":
|
||||
path = self._get_cache_file_path(key)
|
||||
try:
|
||||
with streamlit_write(path, binary=True) as output:
|
||||
output.write(value)
|
||||
except errors.Error as ex:
|
||||
_LOGGER.debug("Unable to write to cache", exc_info=ex)
|
||||
# Clean up file so we don't leave zero byte files.
|
||||
try:
|
||||
os.remove(path)
|
||||
except (FileNotFoundError, OSError):
|
||||
# If we can't remove the file, it's not a big deal.
|
||||
pass
|
||||
raise CacheStorageError("Unable to write to cache") from ex
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete a cache file from disk. If the file does not exist on disk,
|
||||
return silently. If another exception occurs, log it. Does not throw.
|
||||
"""
|
||||
if self.persist == "disk":
|
||||
path = self._get_cache_file_path(key)
|
||||
try:
|
||||
os.remove(path)
|
||||
except FileNotFoundError:
|
||||
# The file is already removed.
|
||||
pass
|
||||
except Exception as ex:
|
||||
_LOGGER.exception(
|
||||
"Unable to remove a file from the disk cache", exc_info=ex
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Delete all keys for the current storage."""
|
||||
cache_dir = get_cache_folder_path()
|
||||
|
||||
if os.path.isdir(cache_dir):
|
||||
# We try to remove all files in the cache directory that start with
|
||||
# the function key, whether `clear` called for `self.persist`
|
||||
# storage or not, to avoid leaving orphaned files in the cache directory.
|
||||
for file_name in os.listdir(cache_dir):
|
||||
if self._is_cache_file(file_name):
|
||||
os.remove(os.path.join(cache_dir, file_name))
|
||||
|
||||
def close(self) -> None:
|
||||
"""Dummy implementation of close, we don't need to actually "close" anything."""
|
||||
|
||||
def _get_cache_file_path(self, value_key: str) -> str:
|
||||
"""Return the path of the disk cache file for the given value."""
|
||||
cache_dir = get_cache_folder_path()
|
||||
return os.path.join(
|
||||
cache_dir, f"{self.function_key}-{value_key}.{_CACHED_FILE_EXTENSION}"
|
||||
)
|
||||
|
||||
def _is_cache_file(self, fname: str) -> bool:
|
||||
"""Return true if the given file name is a cache file for this storage."""
|
||||
return fname.startswith(f"{self.function_key}-") and fname.endswith(
|
||||
f".{_CACHED_FILE_EXTENSION}"
|
||||
)
|
||||
|
||||
|
||||
def get_cache_folder_path() -> str:
|
||||
return get_streamlit_file_path(_CACHE_DIR_NAME)
|
||||
@@ -0,0 +1,435 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal, TypeVar, overload
|
||||
|
||||
from streamlit.connections import (
|
||||
BaseConnection,
|
||||
SnowflakeConnection,
|
||||
SnowparkConnection,
|
||||
SQLConnection,
|
||||
)
|
||||
from streamlit.deprecation_util import deprecate_obj_name
|
||||
from streamlit.errors import StreamlitAPIException
|
||||
from streamlit.runtime.caching import cache_resource
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.secrets import secrets_singleton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
# NOTE: Adding support for a new first party connection requires:
|
||||
# 1. Adding the new connection name and class to this dict.
|
||||
# 2. Writing two new @overloads for connection_factory (one for the case where the
|
||||
# only the connection name is specified and another when both name and type are).
|
||||
# 3. Updating test_get_first_party_connection_helper in connection_factory_test.py.
|
||||
FIRST_PARTY_CONNECTIONS = {
|
||||
"snowflake": SnowflakeConnection,
|
||||
"snowpark": SnowparkConnection,
|
||||
"sql": SQLConnection,
|
||||
}
|
||||
MODULE_EXTRACTION_REGEX = re.compile(r"No module named \'(.+)\'")
|
||||
MODULES_TO_PYPI_PACKAGES: Final[dict[str, str]] = {
|
||||
"MySQLdb": "mysqlclient",
|
||||
"psycopg2": "psycopg2-binary",
|
||||
"sqlalchemy": "sqlalchemy",
|
||||
"snowflake": "snowflake-connector-python",
|
||||
"snowflake.connector": "snowflake-connector-python",
|
||||
"snowflake.snowpark": "snowflake-snowpark-python",
|
||||
}
|
||||
|
||||
# The BaseConnection bound is parameterized to `Any` below as subclasses of
|
||||
# BaseConnection are responsible for binding the type parameter of BaseConnection to a
|
||||
# concrete type, but the type it gets bound to isn't important to us here.
|
||||
ConnectionClass = TypeVar("ConnectionClass", bound=BaseConnection[Any])
|
||||
|
||||
|
||||
@gather_metrics("connection")
|
||||
def _create_connection(
|
||||
name: str,
|
||||
connection_class: type[ConnectionClass],
|
||||
max_entries: int | None = None,
|
||||
ttl: float | timedelta | None = None,
|
||||
**kwargs,
|
||||
) -> ConnectionClass:
|
||||
"""Create an instance of connection_class with the given name and kwargs.
|
||||
|
||||
The weird implementation of this function with the @cache_resource annotated
|
||||
function defined internally is done to:
|
||||
- Always @gather_metrics on the call even if the return value is a cached one.
|
||||
- Allow the user to specify ttl and max_entries when calling st.connection.
|
||||
"""
|
||||
|
||||
def __create_connection(
|
||||
name: str, connection_class: type[ConnectionClass], **kwargs
|
||||
) -> ConnectionClass:
|
||||
return connection_class(connection_name=name, **kwargs)
|
||||
|
||||
if not issubclass(connection_class, BaseConnection):
|
||||
raise StreamlitAPIException(
|
||||
f"{connection_class} is not a subclass of BaseConnection!"
|
||||
)
|
||||
|
||||
# We modify our helper function's `__qualname__` here to work around default
|
||||
# `@st.cache_resource` behavior. Otherwise, `st.connection` being called with
|
||||
# different `ttl` or `max_entries` values will reset the cache with each call.
|
||||
ttl_str = str(ttl).replace( # Avoid adding extra `.` characters to `__qualname__`
|
||||
".", "_"
|
||||
)
|
||||
__create_connection.__qualname__ = (
|
||||
f"{__create_connection.__qualname__}_{ttl_str}_{max_entries}"
|
||||
)
|
||||
__create_connection = cache_resource(
|
||||
max_entries=max_entries,
|
||||
show_spinner="Running `st.connection(...)`.",
|
||||
ttl=ttl,
|
||||
)(__create_connection)
|
||||
|
||||
return __create_connection(name, connection_class, **kwargs)
|
||||
|
||||
|
||||
def _get_first_party_connection(connection_class: str):
|
||||
if connection_class in FIRST_PARTY_CONNECTIONS:
|
||||
return FIRST_PARTY_CONNECTIONS[connection_class]
|
||||
|
||||
raise StreamlitAPIException(
|
||||
f"Invalid connection '{connection_class}'. "
|
||||
f"Supported connection classes: {FIRST_PARTY_CONNECTIONS}"
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def connection_factory(
|
||||
name: Literal["sql"],
|
||||
max_entries: int | None = None,
|
||||
ttl: float | timedelta | None = None,
|
||||
autocommit: bool = False,
|
||||
**kwargs,
|
||||
) -> SQLConnection:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def connection_factory(
|
||||
name: str,
|
||||
type: Literal["sql"],
|
||||
max_entries: int | None = None,
|
||||
ttl: float | timedelta | None = None,
|
||||
autocommit: bool = False,
|
||||
**kwargs,
|
||||
) -> SQLConnection:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def connection_factory(
|
||||
name: Literal["snowflake"],
|
||||
max_entries: int | None = None,
|
||||
ttl: float | timedelta | None = None,
|
||||
autocommit: bool = False,
|
||||
**kwargs,
|
||||
) -> SnowflakeConnection:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def connection_factory(
|
||||
name: str,
|
||||
type: Literal["snowflake"],
|
||||
max_entries: int | None = None,
|
||||
ttl: float | timedelta | None = None,
|
||||
autocommit: bool = False,
|
||||
**kwargs,
|
||||
) -> SnowflakeConnection:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def connection_factory(
|
||||
name: Literal["snowpark"],
|
||||
max_entries: int | None = None,
|
||||
ttl: float | timedelta | None = None,
|
||||
**kwargs,
|
||||
) -> SnowparkConnection:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def connection_factory(
|
||||
name: str,
|
||||
type: Literal["snowpark"],
|
||||
max_entries: int | None = None,
|
||||
ttl: float | timedelta | None = None,
|
||||
**kwargs,
|
||||
) -> SnowparkConnection:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def connection_factory(
|
||||
name: str,
|
||||
type: type[ConnectionClass],
|
||||
max_entries: int | None = None,
|
||||
ttl: float | timedelta | None = None,
|
||||
**kwargs,
|
||||
) -> ConnectionClass:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def connection_factory(
|
||||
name: str,
|
||||
type: str | None = None,
|
||||
max_entries: int | None = None,
|
||||
ttl: float | timedelta | None = None,
|
||||
**kwargs,
|
||||
) -> BaseConnection[Any]:
|
||||
pass
|
||||
|
||||
|
||||
def connection_factory(
|
||||
name,
|
||||
type=None,
|
||||
max_entries=None,
|
||||
ttl=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a new connection to a data store or API, or return an existing one.
|
||||
|
||||
Configuration options, credentials, and secrets for connections are
|
||||
combined from the following sources:
|
||||
|
||||
- The keyword arguments passed to this command.
|
||||
- The app's ``secrets.toml`` files.
|
||||
- Any connection-specific configuration files.
|
||||
|
||||
The connection returned from ``st.connection`` is internally cached with
|
||||
``st.cache_resource`` and is therefore shared between sessions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The connection name used for secrets lookup in ``secrets.toml``.
|
||||
Streamlit uses secrets under ``[connections.<name>]`` for the
|
||||
connection. ``type`` will be inferred if ``name`` is one of the
|
||||
following: ``"snowflake"``, ``"snowpark"``, or ``"sql"``.
|
||||
|
||||
type : str, connection class, or None
|
||||
The type of connection to create. This can be one of the following:
|
||||
|
||||
- ``None`` (default): Streamlit will infer the connection type from
|
||||
``name``. If the type is not inferrable from ``name``, the type must
|
||||
be specified in ``secrets.toml`` instead.
|
||||
- ``"snowflake"``: Streamlit will initialize a connection with
|
||||
|SnowflakeConnection|_.
|
||||
- ``"snowpark"``: Streamlit will initialize a connection with
|
||||
|SnowparkConnection|_. This is deprecated.
|
||||
- ``"sql"``: Streamlit will initialize a connection with
|
||||
|SQLConnection|_.
|
||||
- A string path to an importable class: This must be a dot-separated
|
||||
module path ending in the importable class. Streamlit will import the
|
||||
class and initialize a connection with it. The class must extend
|
||||
``st.connections.BaseConnection``.
|
||||
- An imported class reference: Streamlit will initialize a connection
|
||||
with the referenced class, which must extend
|
||||
``st.connections.BaseConnection``.
|
||||
|
||||
.. |SnowflakeConnection| replace:: ``SnowflakeConnection``
|
||||
.. _SnowflakeConnection: https://docs.streamlit.io/develop/api-reference/connections/st.connections.snowflakeconnection
|
||||
.. |SnowparkConnection| replace:: ``SnowparkConnection``
|
||||
.. _SnowparkConnection: https://docs.streamlit.io/develop/api-reference/connections/st.connections.snowparkconnection
|
||||
.. |SQLConnection| replace:: ``SQLConnection``
|
||||
.. _SQLConnection: https://docs.streamlit.io/develop/api-reference/connections/st.connections.sqlconnection
|
||||
|
||||
max_entries : int or None
|
||||
The maximum number of connections to keep in the cache.
|
||||
If this is ``None`` (default), the cache is unbounded. Otherwise, when
|
||||
a new entry is added to a full cache, the oldest cached entry is
|
||||
removed.
|
||||
ttl : float, timedelta, or None
|
||||
The maximum number of seconds to keep results in the cache.
|
||||
If this is ``None`` (default), cached results do not expire with time.
|
||||
**kwargs : any
|
||||
Connection-specific keyword arguments that are passed to the
|
||||
connection's ``._connect()`` method. ``**kwargs`` are typically
|
||||
combined with (and take precendence over) key-value pairs in
|
||||
``secrets.toml``. To learn more, see the specific connection's
|
||||
documentation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Subclass of BaseConnection
|
||||
An initialized connection object of the specified ``type``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
**Example 1: Inferred connection type**
|
||||
|
||||
The easiest way to create a first-party (SQL, Snowflake, or Snowpark) connection is
|
||||
to use their default names and define corresponding sections in your ``secrets.toml``
|
||||
file. The following example creates a ``"sql"``-type connection.
|
||||
|
||||
``.streamlit/secrets.toml``:
|
||||
|
||||
>>> [connections.sql]
|
||||
>>> dialect = "xxx"
|
||||
>>> host = "xxx"
|
||||
>>> username = "xxx"
|
||||
>>> password = "xxx"
|
||||
|
||||
Your app code:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> conn = st.connection("sql")
|
||||
|
||||
**Example 2: Named connections**
|
||||
|
||||
Creating a connection with a custom name requires you to explicitly
|
||||
specify the type. If ``type`` is not passed as a keyword argument, it must
|
||||
be set in the appropriate section of ``secrets.toml``. The following
|
||||
example creates two ``"sql"``-type connections, each with their own
|
||||
custom name. The first defines ``type`` in the ``st.connection`` command;
|
||||
the second defines ``type`` in ``secrets.toml``.
|
||||
|
||||
``.streamlit/secrets.toml``:
|
||||
|
||||
>>> [connections.first_connection]
|
||||
>>> dialect = "xxx"
|
||||
>>> host = "xxx"
|
||||
>>> username = "xxx"
|
||||
>>> password = "xxx"
|
||||
>>>
|
||||
>>> [connections.second_connection]
|
||||
>>> type = "sql"
|
||||
>>> dialect = "yyy"
|
||||
>>> host = "yyy"
|
||||
>>> username = "yyy"
|
||||
>>> password = "yyy"
|
||||
|
||||
Your app code:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> conn1 = st.connection("first_connection", type="sql")
|
||||
>>> conn2 = st.connection("second_connection")
|
||||
|
||||
**Example 3: Using a path to the connection class**
|
||||
|
||||
Passing the full module path to the connection class can be useful,
|
||||
especially when working with a custom connection. Although this is not the
|
||||
typical way to create first party connections, the following example
|
||||
creates the same type of connection as one with ``type="sql"``. Note that
|
||||
``type`` is a string path.
|
||||
|
||||
``.streamlit/secrets.toml``:
|
||||
|
||||
>>> [connections.my_sql_connection]
|
||||
>>> url = "xxx+xxx://xxx:xxx@xxx:xxx/xxx"
|
||||
|
||||
Your app code:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> conn = st.connection(
|
||||
... "my_sql_connection", type="streamlit.connections.SQLConnection"
|
||||
... )
|
||||
|
||||
**Example 4: Importing the connection class**
|
||||
|
||||
You can pass the connection class directly to the ``st.connection``
|
||||
command. Doing so allows static type checking tools such as ``mypy`` to
|
||||
infer the exact return type of ``st.connection``. The following example
|
||||
creates the same connection as in Example 3.
|
||||
|
||||
``.streamlit/secrets.toml``:
|
||||
|
||||
>>> [connections.my_sql_connection]
|
||||
>>> url = "xxx+xxx://xxx:xxx@xxx:xxx/xxx"
|
||||
|
||||
Your app code:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> from streamlit.connections import SQLConnection
|
||||
>>> conn = st.connection("my_sql_connection", type=SQLConnection)
|
||||
|
||||
"""
|
||||
USE_ENV_PREFIX = "env:"
|
||||
|
||||
if name.startswith(USE_ENV_PREFIX):
|
||||
# It'd be nice to use str.removeprefix() here, but we won't be able to do that
|
||||
# until the minimium Python version we support is 3.9.
|
||||
envvar_name = name[len(USE_ENV_PREFIX) :]
|
||||
name = os.environ[envvar_name]
|
||||
|
||||
if type is None:
|
||||
if name in FIRST_PARTY_CONNECTIONS:
|
||||
# We allow users to simply write `st.connection("sql")` instead of
|
||||
# `st.connection("sql", type="sql")`.
|
||||
type = _get_first_party_connection(name)
|
||||
else:
|
||||
# The user didn't specify a type, so we try to pull it out from their
|
||||
# secrets.toml file. NOTE: we're okay with any of the dict lookups below
|
||||
# exploding with a KeyError since, if type isn't explicitly specified here,
|
||||
# it must be the case that it's defined in secrets.toml and should raise an
|
||||
# Exception otherwise.
|
||||
secrets_singleton.load_if_toml_exists()
|
||||
type = secrets_singleton["connections"][name]["type"]
|
||||
|
||||
# type is a nice kwarg name for the st.connection user but is annoying to work with
|
||||
# since it conflicts with the builtin function name and thus gets syntax
|
||||
# highlighted.
|
||||
connection_class = type
|
||||
|
||||
if isinstance(connection_class, str):
|
||||
# We assume that a connection_class specified via string is either the fully
|
||||
# qualified name of a class (its module and exported classname) or the string
|
||||
# literal shorthand for one of our first party connections. In the former case,
|
||||
# connection_class will always contain a "." in its name.
|
||||
if "." in connection_class:
|
||||
parts = connection_class.split(".")
|
||||
classname = parts.pop()
|
||||
|
||||
import importlib
|
||||
|
||||
connection_module = importlib.import_module(".".join(parts))
|
||||
connection_class = getattr(connection_module, classname)
|
||||
else:
|
||||
connection_class = _get_first_party_connection(connection_class)
|
||||
|
||||
# At this point, connection_class should be of type Type[ConnectionClass].
|
||||
try:
|
||||
conn = _create_connection(
|
||||
name, connection_class, max_entries=max_entries, ttl=ttl, **kwargs
|
||||
)
|
||||
if isinstance(conn, SnowparkConnection):
|
||||
conn = deprecate_obj_name(
|
||||
conn,
|
||||
'connection("snowpark")',
|
||||
'connection("snowflake")',
|
||||
"2024-04-01",
|
||||
)
|
||||
return conn
|
||||
except ModuleNotFoundError as e:
|
||||
err_string = str(e)
|
||||
missing_module = re.search(MODULE_EXTRACTION_REGEX, err_string)
|
||||
|
||||
extra_info = "You may be missing a dependency required to use this connection."
|
||||
if missing_module:
|
||||
pypi_package = MODULES_TO_PYPI_PACKAGES.get(missing_module.group(1))
|
||||
if pypi_package:
|
||||
extra_info = f"You need to install the '{pypi_package}' package to use this connection."
|
||||
|
||||
raise ModuleNotFoundError(f"{str(e)}. {extra_info}")
|
||||
300
myenv/lib/python3.11/site-packages/streamlit/runtime/context.py
Normal file
300
myenv/lib/python3.11/site-packages/streamlit/runtime/context.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Iterator, Mapping
|
||||
from functools import lru_cache
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from streamlit import runtime
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from http.cookies import Morsel
|
||||
|
||||
from tornado.httputil import HTTPHeaders, HTTPServerRequest
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
|
||||
def _get_request() -> HTTPServerRequest | None:
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return None
|
||||
|
||||
session_client = runtime.get_instance().get_client(ctx.session_id)
|
||||
if session_client is None:
|
||||
return None
|
||||
|
||||
# We return websocket request only if session_client is an instance of
|
||||
# BrowserWebSocketHandler (which is True for the Streamlit open-source
|
||||
# implementation). For any other implementation, we return None.
|
||||
# We are not using `type_util.is_type` here to avoid circular import.
|
||||
if (
|
||||
f"{type(session_client).__module__}.{type(session_client).__qualname__}"
|
||||
!= "streamlit.web.server.browser_websocket_handler.BrowserWebSocketHandler"
|
||||
):
|
||||
return None
|
||||
|
||||
return cast("RequestHandler", session_client).request
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _normalize_header(name: str) -> str:
|
||||
"""Map a header name to Http-Header-Case.
|
||||
|
||||
>>> _normalize_header("coNtent-TYPE")
|
||||
'Content-Type'
|
||||
"""
|
||||
return "-".join(w.capitalize() for w in name.split("-"))
|
||||
|
||||
|
||||
class StreamlitHeaders(Mapping[str, str]):
|
||||
def __init__(self, headers: Iterable[tuple[str, str]]):
|
||||
dict_like_headers: dict[str, list[str]] = {}
|
||||
|
||||
for key, value in headers:
|
||||
header_value = dict_like_headers.setdefault(_normalize_header(key), [])
|
||||
header_value.append(value)
|
||||
|
||||
self._headers = dict_like_headers
|
||||
|
||||
@classmethod
|
||||
def from_tornado_headers(cls, tornado_headers: HTTPHeaders) -> StreamlitHeaders:
|
||||
return cls(tornado_headers.get_all())
|
||||
|
||||
def get_all(self, key: str) -> list[str]:
|
||||
return list(self._headers.get(_normalize_header(key), []))
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
try:
|
||||
return self._headers[_normalize_header(key)][0]
|
||||
except LookupError:
|
||||
raise KeyError(key) from None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Number of unique headers present in request."""
|
||||
return len(self._headers)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._headers)
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return {key: self[key] for key in self}
|
||||
|
||||
|
||||
class StreamlitCookies(Mapping[str, str]):
|
||||
def __init__(self, cookies: Mapping[str, str]):
|
||||
self._cookies = MappingProxyType(cookies)
|
||||
|
||||
@classmethod
|
||||
def from_tornado_cookies(
|
||||
cls, tornado_cookies: dict[str, Morsel[Any]]
|
||||
) -> StreamlitCookies:
|
||||
dict_like_cookies = {}
|
||||
for key, morsel in tornado_cookies.items():
|
||||
dict_like_cookies[key] = morsel.value
|
||||
return cls(dict_like_cookies)
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
return self._cookies[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Number of unique headers present in request."""
|
||||
return len(self._cookies)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._cookies)
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return dict(self._cookies)
|
||||
|
||||
|
||||
class ContextProxy:
|
||||
"""An interface to access user session context.
|
||||
|
||||
``st.context`` provides a read-only interface to access headers and cookies
|
||||
for the current user session.
|
||||
|
||||
Each property (``st.context.headers`` and ``st.context.cookies``) returns
|
||||
a dictionary of named values.
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
@gather_metrics("context.headers")
|
||||
def headers(self) -> StreamlitHeaders:
|
||||
"""A read-only, dict-like object containing headers sent in the initial request.
|
||||
|
||||
Keys are case-insensitive and may be repeated. When keys are repeated,
|
||||
dict-like methods will only return the last instance of each key. Use
|
||||
``.get_all(key="your_repeated_key")`` to see all values if the same
|
||||
header is set multiple times.
|
||||
|
||||
Examples
|
||||
--------
|
||||
**Example 1: Access all available headers**
|
||||
|
||||
Show a dictionary of headers (with only the last instance of any
|
||||
repeated key):
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> st.context.headers
|
||||
|
||||
**Example 2: Access a specific header**
|
||||
|
||||
Show the value of a specific header (or the last instance if it's
|
||||
repeated):
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> st.context.headers["host"]
|
||||
|
||||
Show of list of all headers for a given key:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> st.context.headers.get_all("pragma")
|
||||
|
||||
"""
|
||||
# We have a docstring in line above as one-liner, to have a correct docstring
|
||||
# in the st.write(st,context) call.
|
||||
session_client_request = _get_request()
|
||||
|
||||
if session_client_request is None:
|
||||
return StreamlitHeaders({})
|
||||
|
||||
return StreamlitHeaders.from_tornado_headers(session_client_request.headers)
|
||||
|
||||
@property
|
||||
@gather_metrics("context.cookies")
|
||||
def cookies(self) -> StreamlitCookies:
|
||||
"""A read-only, dict-like object containing cookies sent in the initial request.
|
||||
|
||||
Examples
|
||||
--------
|
||||
**Example 1: Access all available cookies**
|
||||
|
||||
Show a dictionary of cookies:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> st.context.cookies
|
||||
|
||||
**Example 2: Access a specific cookie**
|
||||
|
||||
Show the value of a specific cookie:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> st.context.cookies["_ga"]
|
||||
|
||||
"""
|
||||
# We have a docstring in line above as one-liner, to have a correct docstring
|
||||
# in the st.write(st,context) call.
|
||||
session_client_request = _get_request()
|
||||
|
||||
if session_client_request is None:
|
||||
return StreamlitCookies({})
|
||||
|
||||
cookies = session_client_request.cookies
|
||||
return StreamlitCookies.from_tornado_cookies(cookies)
|
||||
|
||||
@property
|
||||
@gather_metrics("context.timezone")
|
||||
def timezone(self) -> str | None:
|
||||
"""The read-only timezone of the user's browser.
|
||||
|
||||
Example
|
||||
-------
|
||||
Access the user's timezone, and format a datetime to display locally:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> from datetime import datetime, timezone
|
||||
>>> import pytz
|
||||
>>>
|
||||
>>> tz = st.context.timezone
|
||||
>>> tz_obj = pytz.timezone(tz)
|
||||
>>>
|
||||
>>> now = datetime.now(timezone.utc)
|
||||
>>>
|
||||
>>> f"The user's timezone is {tz}."
|
||||
>>> f"The UTC time is {now}."
|
||||
>>> f"The user's local time is {now.astimezone(tz_obj)}"
|
||||
|
||||
"""
|
||||
ctx = get_script_run_ctx()
|
||||
|
||||
if ctx is None or ctx.context_info is None:
|
||||
return None
|
||||
return ctx.context_info.timezone
|
||||
|
||||
@property
|
||||
@gather_metrics("context.timezone_offset")
|
||||
def timezone_offset(self) -> int | None:
|
||||
"""The read-only timezone offset of the user's browser.
|
||||
|
||||
Example
|
||||
-------
|
||||
Access the user's timezone offset, and format a datetime to display locally:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> from datetime import datetime, timezone, timedelta
|
||||
>>>
|
||||
>>> tzoff = st.context.timezone_offset
|
||||
>>> tz_obj = timezone(-timedelta(minutes=tzoff))
|
||||
>>>
|
||||
>>> now = datetime.now(timezone.utc)
|
||||
>>>
|
||||
>>> f"The user's timezone is {tz}."
|
||||
>>> f"The UTC time is {now}."
|
||||
>>> f"The user's local time is {now.astimezone(tz_obj)}"
|
||||
|
||||
"""
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None or ctx.context_info is None:
|
||||
return None
|
||||
return ctx.context_info.timezone_offset
|
||||
|
||||
@property
|
||||
@gather_metrics("context.locale")
|
||||
def locale(self) -> str | None:
|
||||
"""The read-only locale of the user's browser.
|
||||
|
||||
``st.context.locale`` returns the value of |navigator.language|_ from
|
||||
the user's DOM. This is a string representing the user's preferred
|
||||
language (e.g. "en-US").
|
||||
|
||||
.. |navigator.language| replace:: ``navigator.language``
|
||||
.. _navigator.language: https://developer.mozilla.org/en-US/docs/Web/API/Navigator/language
|
||||
|
||||
Example
|
||||
-------
|
||||
Access the user's locale to display locally:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> if st.context.locale == "fr-FR":
|
||||
>>> st.write("Bonjour!")
|
||||
>>> else:
|
||||
>>> st.write("Hello!")
|
||||
|
||||
"""
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None or ctx.context_info is None:
|
||||
return None
|
||||
return ctx.context_info.locale
|
||||
@@ -0,0 +1,364 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Manage the user's Streamlit credentials."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import Final, NamedTuple, NoReturn
|
||||
from uuid import uuid4
|
||||
|
||||
from streamlit import cli_util, env_util, file_util, util
|
||||
from streamlit.logger import get_logger
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
if env_util.IS_WINDOWS:
|
||||
_CONFIG_FILE_PATH = r"%userprofile%/.streamlit/config.toml"
|
||||
else:
|
||||
_CONFIG_FILE_PATH = "~/.streamlit/config.toml"
|
||||
|
||||
|
||||
class _Activation(NamedTuple):
|
||||
email: str | None # the user's email.
|
||||
is_valid: bool # whether the email is valid.
|
||||
|
||||
|
||||
def email_prompt() -> str:
|
||||
# Emoji can cause encoding errors on non-UTF-8 terminals
|
||||
# (See https://github.com/streamlit/streamlit/issues/2284.)
|
||||
# WT_SESSION is a Windows Terminal specific environment variable. If it exists,
|
||||
# we are on the latest Windows Terminal that supports emojis
|
||||
show_emoji = sys.stdout.encoding == "utf-8" and (
|
||||
not env_util.IS_WINDOWS or os.environ.get("WT_SESSION")
|
||||
)
|
||||
|
||||
# IMPORTANT: Break the text below at 80 chars.
|
||||
return f"""
|
||||
{"👋 " if show_emoji else ""}{cli_util.style_for_cli("Welcome to Streamlit!", bold=True)}
|
||||
|
||||
If you’d like to receive helpful onboarding emails, news, offers, promotions,
|
||||
and the occasional swag, please enter your email address below. Otherwise,
|
||||
leave this field blank.
|
||||
|
||||
{cli_util.style_for_cli("Email: ", fg="blue")}"""
|
||||
|
||||
|
||||
_TELEMETRY_HEADLESS_TEXT = """
|
||||
Collecting usage statistics. To deactivate, set browser.gatherUsageStats to false.
|
||||
"""
|
||||
|
||||
|
||||
def _send_email(email: str) -> None:
|
||||
"""Send the user's email for metrics, if submitted."""
|
||||
import requests
|
||||
|
||||
if email is None or "@" not in email:
|
||||
return
|
||||
|
||||
metrics_url = ""
|
||||
try:
|
||||
response_json = requests.get(
|
||||
"https://data.streamlit.io/metrics.json", timeout=2
|
||||
).json()
|
||||
metrics_url = response_json.get("url", "")
|
||||
except Exception:
|
||||
_LOGGER.error("Failed to fetch metrics URL")
|
||||
return
|
||||
|
||||
headers = {
|
||||
"accept": "*/*",
|
||||
"accept-language": "en-US,en;q=0.9",
|
||||
"content-type": "application/json",
|
||||
"origin": "localhost:8501",
|
||||
"referer": "localhost:8501/",
|
||||
}
|
||||
|
||||
data = {
|
||||
"anonymous_id": None,
|
||||
"messageId": str(uuid4()),
|
||||
"event": "submittedEmail",
|
||||
"author_email": email,
|
||||
"source": "provided_email",
|
||||
"type": "track",
|
||||
"userId": email,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
metrics_url,
|
||||
headers=headers,
|
||||
data=json.dumps(data).encode(),
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
class Credentials:
|
||||
"""Credentials class."""
|
||||
|
||||
_singleton: Credentials | None = None
|
||||
|
||||
@classmethod
|
||||
def get_current(cls):
|
||||
"""Return the singleton instance."""
|
||||
if cls._singleton is None:
|
||||
Credentials()
|
||||
|
||||
return Credentials._singleton
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize class."""
|
||||
if Credentials._singleton is not None:
|
||||
raise RuntimeError(
|
||||
"Credentials already initialized. Use .get_current() instead"
|
||||
)
|
||||
|
||||
self.activation = None
|
||||
self._conf_file = _get_credential_file_path()
|
||||
|
||||
Credentials._singleton = self
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def load(self, auto_resolve: bool = False) -> None:
|
||||
"""Load from toml file."""
|
||||
if self.activation is not None:
|
||||
_LOGGER.error("Credentials already loaded. Not rereading file.")
|
||||
return
|
||||
|
||||
import toml
|
||||
|
||||
try:
|
||||
with open(self._conf_file) as f:
|
||||
data = toml.load(f).get("general")
|
||||
if data is None:
|
||||
raise Exception
|
||||
self.activation = _verify_email(data.get("email"))
|
||||
except FileNotFoundError:
|
||||
if auto_resolve:
|
||||
self.activate(show_instructions=not auto_resolve)
|
||||
return
|
||||
raise RuntimeError(
|
||||
'Credentials not found. Please run "streamlit activate".'
|
||||
)
|
||||
except Exception:
|
||||
if auto_resolve:
|
||||
self.reset()
|
||||
self.activate(show_instructions=not auto_resolve)
|
||||
return
|
||||
raise Exception(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Unable to load credentials from %s.
|
||||
Run "streamlit reset" and try again.
|
||||
"""
|
||||
)
|
||||
% (self._conf_file)
|
||||
)
|
||||
|
||||
def _check_activated(self, auto_resolve: bool = True) -> None:
|
||||
"""Check if streamlit is activated.
|
||||
|
||||
Used by `streamlit run script.py`
|
||||
"""
|
||||
try:
|
||||
self.load(auto_resolve)
|
||||
except (Exception, RuntimeError) as e:
|
||||
_exit(str(e))
|
||||
|
||||
if self.activation is None or not self.activation.is_valid:
|
||||
_exit("Activation email not valid.")
|
||||
|
||||
@classmethod
|
||||
def reset(cls) -> None:
|
||||
"""Reset credentials by removing file.
|
||||
|
||||
This is used by `streamlit activate reset` in case a user wants
|
||||
to start over.
|
||||
"""
|
||||
c = Credentials.get_current()
|
||||
c.activation = None
|
||||
|
||||
try:
|
||||
os.remove(c._conf_file)
|
||||
except OSError:
|
||||
_LOGGER.exception("Error removing credentials file.")
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save to toml file and send email."""
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
if self.activation is None:
|
||||
return
|
||||
|
||||
# Create intermediate directories if necessary
|
||||
os.makedirs(os.path.dirname(self._conf_file), exist_ok=True)
|
||||
|
||||
# Write the file
|
||||
data = {"email": self.activation.email}
|
||||
|
||||
import toml
|
||||
|
||||
with open(self._conf_file, "w") as f:
|
||||
toml.dump({"general": data}, f)
|
||||
|
||||
try:
|
||||
_send_email(self.activation.email)
|
||||
except RequestException:
|
||||
_LOGGER.exception("Error saving email:")
|
||||
|
||||
def activate(self, show_instructions: bool = True) -> None:
|
||||
"""Activate Streamlit.
|
||||
|
||||
Used by `streamlit activate`.
|
||||
"""
|
||||
try:
|
||||
self.load()
|
||||
except RuntimeError:
|
||||
# Runtime Error is raised if credentials file is not found. In that case,
|
||||
# `self.activation` is None and we will show the activation prompt below.
|
||||
pass
|
||||
|
||||
if self.activation:
|
||||
if self.activation.is_valid:
|
||||
_exit("Already activated")
|
||||
else:
|
||||
_exit(
|
||||
"Activation not valid. Please run "
|
||||
"`streamlit activate reset` then `streamlit activate`"
|
||||
)
|
||||
else:
|
||||
activated = False
|
||||
|
||||
while not activated:
|
||||
import click
|
||||
|
||||
email = click.prompt(
|
||||
text=email_prompt(),
|
||||
prompt_suffix="",
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
|
||||
self.activation = _verify_email(email)
|
||||
if self.activation.is_valid:
|
||||
self.save()
|
||||
# IMPORTANT: Break the text below at 80 chars.
|
||||
TELEMETRY_TEXT = """
|
||||
You can find our privacy policy at %(link)s
|
||||
|
||||
Summary:
|
||||
- This open source library collects usage statistics.
|
||||
- We cannot see and do not store information contained inside Streamlit apps,
|
||||
such as text, charts, images, etc.
|
||||
- Telemetry data is stored in servers in the United States.
|
||||
- If you'd like to opt out, add the following to %(config)s,
|
||||
creating that file if necessary:
|
||||
|
||||
[browser]
|
||||
gatherUsageStats = false
|
||||
""" % {
|
||||
"link": cli_util.style_for_cli(
|
||||
"https://streamlit.io/privacy-policy", underline=True
|
||||
),
|
||||
"config": cli_util.style_for_cli(_CONFIG_FILE_PATH),
|
||||
}
|
||||
|
||||
cli_util.print_to_cli(TELEMETRY_TEXT)
|
||||
if show_instructions:
|
||||
# IMPORTANT: Break the text below at 80 chars.
|
||||
INSTRUCTIONS_TEXT = """
|
||||
%(start)s
|
||||
%(prompt)s %(hello)s
|
||||
""" % {
|
||||
"start": cli_util.style_for_cli(
|
||||
"Get started by typing:", fg="blue", bold=True
|
||||
),
|
||||
"prompt": cli_util.style_for_cli("$", fg="blue"),
|
||||
"hello": cli_util.style_for_cli(
|
||||
"streamlit hello", bold=True
|
||||
),
|
||||
}
|
||||
|
||||
cli_util.print_to_cli(INSTRUCTIONS_TEXT)
|
||||
activated = True
|
||||
else: # pragma: nocover
|
||||
_LOGGER.error("Please try again.")
|
||||
|
||||
|
||||
def _verify_email(email: str) -> _Activation:
|
||||
"""Verify the user's email address.
|
||||
|
||||
The email can either be an empty string (if the user chooses not to enter
|
||||
it), or a string with a single '@' somewhere in it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
email : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
_Activation
|
||||
An _Activation object. Its 'is_valid' property will be True only if
|
||||
the email was validated.
|
||||
|
||||
"""
|
||||
email = email.strip()
|
||||
|
||||
# We deliberately use simple email validation here
|
||||
# since we do not use email address anywhere to send emails.
|
||||
if len(email) > 0 and email.count("@") != 1:
|
||||
_LOGGER.error("That doesn't look like an email :(")
|
||||
return _Activation(None, False)
|
||||
|
||||
return _Activation(email, True)
|
||||
|
||||
|
||||
def _exit(message: str) -> NoReturn:
|
||||
"""Exit program with error."""
|
||||
_LOGGER.error(message)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def _get_credential_file_path() -> str:
|
||||
return file_util.get_streamlit_file_path("credentials.toml")
|
||||
|
||||
|
||||
def _check_credential_file_exists() -> bool:
|
||||
return os.path.exists(_get_credential_file_path())
|
||||
|
||||
|
||||
def check_credentials() -> None:
|
||||
"""Check credentials and potentially activate.
|
||||
|
||||
Note
|
||||
----
|
||||
If there is no credential file and we are in headless mode, we should not
|
||||
check, since credential would be automatically set to an empty string.
|
||||
|
||||
"""
|
||||
from streamlit import config
|
||||
|
||||
if not _check_credential_file_exists() and config.get_option("server.headless"):
|
||||
if not config.is_manually_set("browser.gatherUsageStats"):
|
||||
# If not manually defined, show short message about usage stats gathering.
|
||||
cli_util.print_to_cli(_TELEMETRY_HEADLESS_TEXT)
|
||||
return
|
||||
Credentials.get_current()._check_activated()
|
||||
@@ -0,0 +1,296 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
from streamlit import config, util
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
from streamlit.runtime.app_session import AppSession
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
def populate_hash_if_needed(msg: ForwardMsg) -> str:
|
||||
"""Computes and assigns the unique hash for a ForwardMsg.
|
||||
|
||||
If the ForwardMsg already has a hash, this is a no-op.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : ForwardMsg
|
||||
|
||||
Returns
|
||||
-------
|
||||
string
|
||||
The message's hash, returned here for convenience. (The hash
|
||||
will also be assigned to the ForwardMsg; callers do not need
|
||||
to do this.)
|
||||
|
||||
"""
|
||||
if msg.hash == "":
|
||||
# Move the message's metadata aside. It's not part of the
|
||||
# hash calculation.
|
||||
metadata = msg.metadata
|
||||
msg.ClearField("metadata")
|
||||
|
||||
# MD5 is good enough for what we need, which is uniqueness.
|
||||
msg.hash = util.calc_md5(msg.SerializeToString())
|
||||
|
||||
# Restore metadata.
|
||||
msg.metadata.CopyFrom(metadata)
|
||||
|
||||
return msg.hash
|
||||
|
||||
|
||||
def create_reference_msg(msg: ForwardMsg) -> ForwardMsg:
|
||||
"""Create a ForwardMsg that refers to the given message via its hash.
|
||||
|
||||
The reference message will also get a copy of the source message's
|
||||
metadata.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : ForwardMsg
|
||||
The ForwardMsg to create the reference to.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ForwardMsg
|
||||
A new ForwardMsg that "points" to the original message via the
|
||||
ref_hash field.
|
||||
|
||||
"""
|
||||
ref_msg = ForwardMsg()
|
||||
ref_msg.ref_hash = populate_hash_if_needed(msg)
|
||||
ref_msg.metadata.CopyFrom(msg.metadata)
|
||||
return ref_msg
|
||||
|
||||
|
||||
class ForwardMsgCache(CacheStatsProvider):
|
||||
"""A cache of ForwardMsgs.
|
||||
|
||||
Large ForwardMsgs (e.g. those containing big DataFrame payloads) are
|
||||
stored in this cache. The server can choose to send a ForwardMsg's hash,
|
||||
rather than the message itself, to a client. Clients can then
|
||||
request messages from this cache via another endpoint.
|
||||
|
||||
This cache is *not* thread safe. It's intended to only be accessed by
|
||||
the server thread.
|
||||
|
||||
"""
|
||||
|
||||
class Entry:
|
||||
"""Cache entry.
|
||||
|
||||
Stores the cached message, and the set of AppSessions
|
||||
that we've sent the cached message to.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, msg: ForwardMsg | None):
|
||||
self.msg = msg
|
||||
self._session_script_run_counts: MutableMapping[AppSession, int] = (
|
||||
WeakKeyDictionary()
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def add_session_ref(self, session: AppSession, script_run_count: int) -> None:
|
||||
"""Adds a reference to a AppSession that has referenced
|
||||
this Entry's message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : AppSession
|
||||
script_run_count : int
|
||||
The session's run count at the time of the call
|
||||
|
||||
"""
|
||||
prev_run_count = self._session_script_run_counts.get(session, 0)
|
||||
if script_run_count < prev_run_count:
|
||||
_LOGGER.error(
|
||||
"New script_run_count (%s) is < prev_run_count (%s). "
|
||||
"This should never happen!",
|
||||
script_run_count,
|
||||
prev_run_count,
|
||||
)
|
||||
script_run_count = prev_run_count
|
||||
self._session_script_run_counts[session] = script_run_count
|
||||
|
||||
def has_session_ref(self, session: AppSession) -> bool:
|
||||
return session in self._session_script_run_counts
|
||||
|
||||
def get_session_ref_age(
|
||||
self, session: AppSession, script_run_count: int
|
||||
) -> int:
|
||||
"""The age of the given session's reference to the Entry,
|
||||
given a new script_run_count.
|
||||
|
||||
"""
|
||||
return script_run_count - self._session_script_run_counts[session]
|
||||
|
||||
def remove_session_ref(self, session: AppSession) -> None:
|
||||
del self._session_script_run_counts[session]
|
||||
|
||||
def has_refs(self) -> bool:
|
||||
"""True if this Entry has references from any AppSession.
|
||||
|
||||
If not, it can be removed from the cache.
|
||||
"""
|
||||
return len(self._session_script_run_counts) > 0
|
||||
|
||||
def __init__(self):
|
||||
self._entries: dict[str, ForwardMsgCache.Entry] = {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def add_message(
|
||||
self, msg: ForwardMsg, session: AppSession, script_run_count: int
|
||||
) -> None:
|
||||
"""Add a ForwardMsg to the cache.
|
||||
|
||||
The cache will also record a reference to the given AppSession,
|
||||
so that it can track which sessions have already received
|
||||
each given ForwardMsg.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : ForwardMsg
|
||||
session : AppSession
|
||||
script_run_count : int
|
||||
The number of times the session's script has run
|
||||
|
||||
"""
|
||||
populate_hash_if_needed(msg)
|
||||
entry = self._entries.get(msg.hash, None)
|
||||
if entry is None:
|
||||
if config.get_option("global.storeCachedForwardMessagesInMemory"):
|
||||
entry = ForwardMsgCache.Entry(msg)
|
||||
else:
|
||||
entry = ForwardMsgCache.Entry(None)
|
||||
self._entries[msg.hash] = entry
|
||||
entry.add_session_ref(session, script_run_count)
|
||||
|
||||
def get_message(self, hash: str) -> ForwardMsg | None:
|
||||
"""Return the message with the given ID if it exists in the cache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hash : str
|
||||
The id of the message to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ForwardMsg | None
|
||||
|
||||
"""
|
||||
entry = self._entries.get(hash, None)
|
||||
return entry.msg if entry else None
|
||||
|
||||
def has_message_reference(
|
||||
self, msg: ForwardMsg, session: AppSession, script_run_count: int
|
||||
) -> bool:
|
||||
"""Return True if a session has a reference to a message."""
|
||||
populate_hash_if_needed(msg)
|
||||
|
||||
entry = self._entries.get(msg.hash, None)
|
||||
if entry is None or not entry.has_session_ref(session):
|
||||
return False
|
||||
|
||||
# Ensure we're not expired
|
||||
age = entry.get_session_ref_age(session, script_run_count)
|
||||
return age <= int(config.get_option("global.maxCachedMessageAge"))
|
||||
|
||||
def remove_refs_for_session(self, session: AppSession) -> None:
|
||||
"""Remove refs for all entries for the given session.
|
||||
|
||||
This should be called when an AppSession is disconnected or closed.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : AppSession
|
||||
"""
|
||||
|
||||
# Operate on a copy of our entries dict.
|
||||
# We may be deleting from it.
|
||||
for msg_hash, entry in self._entries.copy().items():
|
||||
if entry.has_session_ref(session):
|
||||
entry.remove_session_ref(session)
|
||||
|
||||
if not entry.has_refs():
|
||||
# The entry has no more references. Remove it from
|
||||
# the cache completely.
|
||||
del self._entries[msg_hash]
|
||||
|
||||
def remove_expired_entries_for_session(
|
||||
self, session: AppSession, script_run_count: int
|
||||
) -> None:
|
||||
"""Remove any cached messages that have expired from the given session.
|
||||
|
||||
This should be called each time a AppSession finishes executing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : AppSession
|
||||
script_run_count : int
|
||||
The number of times the session's script has run
|
||||
|
||||
"""
|
||||
max_age = config.get_option("global.maxCachedMessageAge")
|
||||
|
||||
# Operate on a copy of our entries dict.
|
||||
# We may be deleting from it.
|
||||
for msg_hash, entry in self._entries.copy().items():
|
||||
if not entry.has_session_ref(session):
|
||||
continue
|
||||
|
||||
age = entry.get_session_ref_age(session, script_run_count)
|
||||
if age > max_age:
|
||||
_LOGGER.debug(
|
||||
"Removing expired entry [session=%s, hash=%s, age=%s]",
|
||||
id(session),
|
||||
msg_hash,
|
||||
age,
|
||||
)
|
||||
entry.remove_session_ref(session)
|
||||
if not entry.has_refs():
|
||||
# The entry has no more references. Remove it from
|
||||
# the cache completely.
|
||||
del self._entries[msg_hash]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all entries from the cache."""
|
||||
self._entries.clear()
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
stats: list[CacheStat] = [
|
||||
CacheStat(
|
||||
category_name="ForwardMessageCache",
|
||||
cache_name="",
|
||||
byte_length=entry.msg.ByteSize() if entry.msg is not None else 0,
|
||||
)
|
||||
for _, entry in self._entries.items()
|
||||
]
|
||||
return group_stats(stats)
|
||||
@@ -0,0 +1,241 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.proto.Delta_pb2 import Delta
|
||||
|
||||
|
||||
class ForwardMsgQueue:
|
||||
"""Accumulates a session's outgoing ForwardMsgs.
|
||||
|
||||
Each AppSession adds messages to its queue, and the Server periodically
|
||||
flushes all session queues and delivers their messages to the appropriate
|
||||
clients.
|
||||
|
||||
ForwardMsgQueue is not thread-safe - a queue should only be used from
|
||||
a single thread.
|
||||
"""
|
||||
|
||||
_before_enqueue_msg: Callable[[ForwardMsg], None] | None = None
|
||||
|
||||
@staticmethod
|
||||
def on_before_enqueue_msg(
|
||||
before_enqueue_msg: Callable[[ForwardMsg], None] | None,
|
||||
) -> None:
|
||||
"""Set a callback to be called before a message is enqueued.
|
||||
Used in static streamlit app generation.
|
||||
"""
|
||||
ForwardMsgQueue._before_enqueue_msg = before_enqueue_msg
|
||||
|
||||
def __init__(self):
|
||||
self._queue: list[ForwardMsg] = []
|
||||
# A mapping of (delta_path -> _queue.indexof(msg)) for each
|
||||
# Delta message in the queue. We use this for coalescing
|
||||
# redundant outgoing Deltas (where a newer Delta supersedes
|
||||
# an older Delta, with the same delta_path, that's still in the
|
||||
# queue).
|
||||
self._delta_index_map: dict[tuple[int, ...], int] = {}
|
||||
|
||||
def get_debug(self) -> dict[str, Any]:
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
return {
|
||||
"queue": [MessageToDict(m) for m in self._queue],
|
||||
"ids": list(self._delta_index_map.keys()),
|
||||
}
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return len(self._queue) == 0
|
||||
|
||||
def enqueue(self, msg: ForwardMsg) -> None:
|
||||
"""Add message into queue, possibly composing it with another message."""
|
||||
|
||||
if ForwardMsgQueue._before_enqueue_msg:
|
||||
ForwardMsgQueue._before_enqueue_msg(msg)
|
||||
|
||||
if not _is_composable_message(msg):
|
||||
self._queue.append(msg)
|
||||
return
|
||||
|
||||
# If there's a Delta message with the same delta_path already in
|
||||
# the queue - meaning that it refers to the same location in
|
||||
# the app - we attempt to combine this new Delta into the old
|
||||
# one. This is an optimization that prevents redundant Deltas
|
||||
# from being sent to the frontend.
|
||||
delta_key = tuple(msg.metadata.delta_path)
|
||||
if delta_key in self._delta_index_map:
|
||||
index = self._delta_index_map[delta_key]
|
||||
old_msg = self._queue[index]
|
||||
composed_delta = _maybe_compose_deltas(old_msg.delta, msg.delta)
|
||||
if composed_delta is not None:
|
||||
new_msg = ForwardMsg()
|
||||
new_msg.delta.CopyFrom(composed_delta)
|
||||
new_msg.metadata.CopyFrom(msg.metadata)
|
||||
self._queue[index] = new_msg
|
||||
return
|
||||
|
||||
# No composition occurred. Append this message to the queue, and
|
||||
# store its index for potential future composition.
|
||||
self._delta_index_map[delta_key] = len(self._queue)
|
||||
self._queue.append(msg)
|
||||
|
||||
def clear(
|
||||
self,
|
||||
retain_lifecycle_msgs: bool = False,
|
||||
fragment_ids_this_run: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Clear the queue, potentially retaining lifecycle messages.
|
||||
|
||||
The retain_lifecycle_msgs argument exists because in some cases (in particular
|
||||
when a currently running script is interrupted by a new BackMsg), we don't want
|
||||
to remove certain messages from the queue as doing so may cause the client to
|
||||
not hear about important script lifecycle events (such as the script being
|
||||
stopped early in order to be rerun).
|
||||
|
||||
If fragment_ids_this_run is provided, delta messages not belonging to any
|
||||
fragment or belonging to a fragment not in fragment_ids_this_run will be
|
||||
preserved to prevent clearing messages unrelated to the running fragments.
|
||||
"""
|
||||
|
||||
if not retain_lifecycle_msgs:
|
||||
self._queue = []
|
||||
else:
|
||||
self._queue = [
|
||||
_update_script_finished_message(msg, fragment_ids_this_run is not None)
|
||||
for msg in self._queue
|
||||
if msg.WhichOneof("type")
|
||||
in {
|
||||
"new_session",
|
||||
"script_finished",
|
||||
"session_status_changed",
|
||||
"parent_message",
|
||||
}
|
||||
or (
|
||||
# preserve all messages if this is a fragment rerun and...
|
||||
fragment_ids_this_run is not None
|
||||
and (
|
||||
# the message is not a delta message
|
||||
# (not associated with a fragment) or...
|
||||
msg.delta is None
|
||||
or (
|
||||
# it is a delta but not associated with any of the passed
|
||||
# fragments
|
||||
msg.delta is not None
|
||||
and (
|
||||
msg.delta.fragment_id is None
|
||||
or msg.delta.fragment_id not in fragment_ids_this_run
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
self._delta_index_map = {}
|
||||
|
||||
def flush(self) -> list[ForwardMsg]:
|
||||
"""Clear the queue and return a list of the messages it contained
|
||||
before being cleared.
|
||||
"""
|
||||
queue = self._queue
|
||||
self.clear()
|
||||
return queue
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._queue)
|
||||
|
||||
|
||||
def _is_composable_message(msg: ForwardMsg) -> bool:
|
||||
"""True if the ForwardMsg is potentially composable with other ForwardMsgs."""
|
||||
if not msg.HasField("delta"):
|
||||
# Non-delta messages are never composable.
|
||||
return False
|
||||
|
||||
# We never compose add_rows messages in Python, because the add_rows
|
||||
# operation can raise errors, and we don't have a good way of handling
|
||||
# those errors in the message queue.
|
||||
delta_type = msg.delta.WhichOneof("type")
|
||||
return delta_type != "add_rows" and delta_type != "arrow_add_rows"
|
||||
|
||||
|
||||
def _maybe_compose_deltas(old_delta: Delta, new_delta: Delta) -> Delta | None:
|
||||
"""Combines new_delta onto old_delta if possible.
|
||||
|
||||
If the combination takes place, the function returns a new Delta that
|
||||
should replace old_delta in the queue.
|
||||
|
||||
If the new_delta is incompatible with old_delta, the function returns None.
|
||||
In this case, the new_delta should just be appended to the queue as normal.
|
||||
"""
|
||||
old_delta_type = old_delta.WhichOneof("type")
|
||||
if old_delta_type == "add_block":
|
||||
# We never replace add_block deltas, because blocks can have
|
||||
# other dependent deltas later in the queue. For example:
|
||||
#
|
||||
# placeholder = st.empty()
|
||||
# placeholder.columns(1)
|
||||
# placeholder.empty()
|
||||
#
|
||||
# The call to "placeholder.columns(1)" creates two blocks, a parent
|
||||
# container with delta_path (0, 0), and a column child with
|
||||
# delta_path (0, 0, 0). If the final "placeholder.empty()" Delta
|
||||
# is composed with the parent container Delta, the frontend will
|
||||
# throw an error when it tries to add that column child to what is
|
||||
# now just an element, and not a block.
|
||||
return None
|
||||
|
||||
new_delta_type = new_delta.WhichOneof("type")
|
||||
if new_delta_type == "new_element":
|
||||
return new_delta
|
||||
|
||||
if new_delta_type == "add_block":
|
||||
return new_delta
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _update_script_finished_message(
|
||||
msg: ForwardMsg, is_fragment_run: bool
|
||||
) -> ForwardMsg:
|
||||
"""
|
||||
When we are here, the message queue is cleared from non-lifecycle messages
|
||||
before they were flushed to the browser.
|
||||
|
||||
If there were no non-lifecycle messages in the queue, changing the type here
|
||||
should not matter for the frontend anyways, so we optimistically change the
|
||||
`script_finished` message to `FINISHED_EARLY_FOR_RERUN`. This indicates to
|
||||
the frontend that the previous run was interrupted by a new script start.
|
||||
Otherwise, a `FINISHED_SUCCESSFULLY` message might trigger a reset of widget
|
||||
states on the frontend.
|
||||
"""
|
||||
if msg.WhichOneof("type") == "script_finished" and (
|
||||
# If this is not a fragment run (= full app run), its okay to change the
|
||||
# script_finished type to FINISHED_EARLY_FOR_RERUN because another full app run
|
||||
# is about to start.
|
||||
# If this is a fragment run, it is allowed to change the state of
|
||||
# all script_finished states except for FINISHED_SUCCESSFULLY, which we use to
|
||||
# indicate that a full app run has finished successfully (in other words, a
|
||||
# fragment should not modify the finished status of a full app run, because
|
||||
# the fragment finished state is different and the frontend might not trigger
|
||||
# cleanups etc. correctly).
|
||||
is_fragment_run is False
|
||||
or msg.script_finished != ForwardMsg.ScriptFinishedStatus.FINISHED_SUCCESSFULLY
|
||||
):
|
||||
msg.script_finished = ForwardMsg.ScriptFinishedStatus.FINISHED_EARLY_FOR_RERUN
|
||||
return msg
|
||||
478
myenv/lib/python3.11/site-packages/streamlit/runtime/fragment.py
Normal file
478
myenv/lib/python3.11/site-packages/streamlit/runtime/fragment.py
Normal file
@@ -0,0 +1,478 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
from abc import abstractmethod
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar, overload
|
||||
|
||||
from streamlit.deprecation_util import (
|
||||
make_deprecated_name_warning,
|
||||
show_deprecation_warning,
|
||||
)
|
||||
from streamlit.error_util import handle_uncaught_app_exception
|
||||
from streamlit.errors import FragmentHandledException, FragmentStorageKeyError
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.scriptrunner_utils.exceptions import (
|
||||
RerunException,
|
||||
StopException,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
|
||||
from streamlit.time_util import time_to_seconds
|
||||
from streamlit.util import calc_md5
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
Fragment = Callable[[], Any]
|
||||
|
||||
|
||||
class FragmentStorage(Protocol):
|
||||
"""A key-value store for Fragments. Used to implement the @st.fragment decorator.
|
||||
|
||||
We intentionally define this as its own protocol despite how generic it appears to
|
||||
be at first glance. The reason why is that, in any case where fragments aren't just
|
||||
stored as Python closures in memory, storing and retrieving Fragments will generally
|
||||
involve serializing and deserializing function bytecode, which is a tricky aspect
|
||||
to implementing FragmentStorages that won't generally appear with our other *Storage
|
||||
protocols.
|
||||
"""
|
||||
|
||||
# Weirdly, we have to define this above the `set` method, or mypy gets it confused
|
||||
# with the `set` type of `new_fragments_ids`.
|
||||
@abstractmethod
|
||||
def clear(self, new_fragment_ids: set[str] | None = None) -> None:
|
||||
"""Remove all fragments saved in this FragmentStorage unless listed in
|
||||
new_fragment_ids.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str) -> Fragment:
|
||||
"""Returns the stored fragment for the given key."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: Fragment) -> None:
|
||||
"""Saves a fragment under the given key."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete the fragment corresponding to the given key."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def contains(self, key: str) -> bool:
|
||||
"""Return whether the given key is present in this FragmentStorage."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# NOTE: Ideally, we'd like to add a MemoryFragmentStorageStatProvider implementation to
|
||||
# keep track of memory usage due to fragments, but doing something like this ends up
|
||||
# being difficult in practice as the memory usage of a closure is hard to measure (the
|
||||
# vendored implementation of pympler.asizeof that we use elsewhere is unable to measure
|
||||
# the size of a function).
|
||||
class MemoryFragmentStorage(FragmentStorage):
|
||||
"""A simple, memory-backed implementation of FragmentStorage.
|
||||
|
||||
MemoryFragmentStorage is just a wrapper around a plain Python dict that complies with
|
||||
the FragmentStorage protocol.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._fragments: dict[str, Fragment] = {}
|
||||
|
||||
# Weirdly, we have to define this above the `set` method, or mypy gets it confused
|
||||
# with the `set` type of `new_fragments_ids`.
|
||||
def clear(self, new_fragment_ids: set[str] | None = None) -> None:
|
||||
if new_fragment_ids is None:
|
||||
new_fragment_ids = set()
|
||||
|
||||
fragment_ids = list(self._fragments.keys())
|
||||
|
||||
for fid in fragment_ids:
|
||||
if fid not in new_fragment_ids:
|
||||
del self._fragments[fid]
|
||||
|
||||
def get(self, key: str) -> Fragment:
|
||||
try:
|
||||
return self._fragments[key]
|
||||
except KeyError as e:
|
||||
raise FragmentStorageKeyError(str(e))
|
||||
|
||||
def set(self, key: str, value: Fragment) -> None:
|
||||
self._fragments[key] = value
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
try:
|
||||
del self._fragments[key]
|
||||
except KeyError as e:
|
||||
raise FragmentStorageKeyError(str(e))
|
||||
|
||||
def contains(self, key: str) -> bool:
|
||||
return key in self._fragments
|
||||
|
||||
|
||||
def _fragment(
|
||||
func: F | None = None,
|
||||
*,
|
||||
run_every: int | float | timedelta | str | None = None,
|
||||
additional_hash_info: str = "",
|
||||
should_show_deprecation_warning: bool = False,
|
||||
) -> Callable[[F], F] | F:
|
||||
"""Contains the actual fragment logic.
|
||||
|
||||
This function should be used by our internal functions that use fragments
|
||||
under-the-hood, so that fragment metrics are not tracked for those elements
|
||||
(note that the @gather_metrics annotation is only on the publicly exposed function)
|
||||
"""
|
||||
|
||||
if func is None:
|
||||
# Support passing the params via function decorator
|
||||
def wrapper(f: F) -> F:
|
||||
return fragment(
|
||||
func=f,
|
||||
run_every=run_every,
|
||||
)
|
||||
|
||||
return wrapper
|
||||
else:
|
||||
non_optional_func = func
|
||||
|
||||
@wraps(non_optional_func)
|
||||
def wrap(*args, **kwargs):
|
||||
from streamlit.delta_generator_singletons import context_dg_stack
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return None
|
||||
|
||||
cursors_snapshot = deepcopy(ctx.cursors)
|
||||
dg_stack_snapshot = deepcopy(context_dg_stack.get())
|
||||
fragment_id = calc_md5(
|
||||
f"{non_optional_func.__module__}.{non_optional_func.__qualname__}{dg_stack_snapshot[-1]._get_delta_path_str()}{additional_hash_info}"
|
||||
)
|
||||
|
||||
# We intentionally want to capture the active script hash here to ensure
|
||||
# that the fragment is associated with the correct script running.
|
||||
initialized_active_script_hash = ctx.active_script_hash
|
||||
|
||||
def wrapped_fragment():
|
||||
import streamlit as st
|
||||
|
||||
if should_show_deprecation_warning:
|
||||
show_deprecation_warning(
|
||||
make_deprecated_name_warning(
|
||||
"experimental_fragment",
|
||||
"fragment",
|
||||
"2025-01-01",
|
||||
)
|
||||
)
|
||||
|
||||
# NOTE: We need to call get_script_run_ctx here again and can't just use the
|
||||
# value of ctx from above captured by the closure because subsequent
|
||||
# fragment runs will generally run in a new script run, thus we'll have a
|
||||
# new ctx.
|
||||
ctx = get_script_run_ctx(suppress_warning=True)
|
||||
assert ctx is not None
|
||||
|
||||
if ctx.fragment_ids_this_run:
|
||||
# This script run is a run of one or more fragments. We restore the
|
||||
# state of ctx.cursors and dg_stack to the snapshots we took when this
|
||||
# fragment was declared.
|
||||
ctx.cursors = deepcopy(cursors_snapshot)
|
||||
context_dg_stack.set(deepcopy(dg_stack_snapshot))
|
||||
|
||||
# Always add the fragment id to new_fragment_ids. For full app runs
|
||||
# we need to add them anyways and for fragment runs we add them
|
||||
# in case the to-be-executed fragment id was cleared from the storage
|
||||
# by the full app run.
|
||||
ctx.new_fragment_ids.add(fragment_id)
|
||||
# Set ctx.current_fragment_id so that elements corresponding to this
|
||||
# fragment get tagged with the appropriate ID. ctx.current_fragment_id gets
|
||||
# reset after the fragment function finishes running to either return to the
|
||||
# script (outside of any fragments) or to the outer fragment this one is
|
||||
# nested in.
|
||||
prev_fragment_id = ctx.current_fragment_id
|
||||
ctx.current_fragment_id = fragment_id
|
||||
|
||||
try:
|
||||
# Make sure we set the active script hash to the same value
|
||||
# for the fragment run as when defined upon initialization
|
||||
# This ensures that elements (especially widgets) are tied
|
||||
# to a consistent active script hash
|
||||
active_hash_context = (
|
||||
ctx.run_with_active_hash(initialized_active_script_hash)
|
||||
if initialized_active_script_hash != ctx.active_script_hash
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
result = None
|
||||
with active_hash_context:
|
||||
with st.container():
|
||||
try:
|
||||
# use dg_stack instead of active_dg to have correct copy
|
||||
# during execution (otherwise we can run into concurrency
|
||||
# issues with multiple fragments). Use dg_stack because we
|
||||
# just entered a container and [:-1] of the delta path
|
||||
# because thats the prefix of the fragment,
|
||||
# e.g. [0, 3, 0] -> [0, 3].
|
||||
# All fragment elements start with [0, 3].
|
||||
active_dg = context_dg_stack.get()[-1]
|
||||
ctx.current_fragment_delta_path = (
|
||||
active_dg._cursor.delta_path
|
||||
if active_dg._cursor
|
||||
else []
|
||||
)[:-1]
|
||||
result = non_optional_func(*args, **kwargs)
|
||||
except (
|
||||
RerunException,
|
||||
StopException,
|
||||
) as e:
|
||||
# The wrapped_fragment function is executed
|
||||
# inside of a exec_func_with_error_handling call, so
|
||||
# there is a correct handler for these exceptions.
|
||||
raise e
|
||||
except Exception as e:
|
||||
# render error here so that the delta path is correct
|
||||
# for full app runs, the error will be displayed by the
|
||||
# main code handler
|
||||
# if not is_full_app_run:
|
||||
handle_uncaught_app_exception(e)
|
||||
# raise here again in case we are in full app execution
|
||||
# and some flags have to be set
|
||||
raise FragmentHandledException(e)
|
||||
return result
|
||||
finally:
|
||||
ctx.current_fragment_id = prev_fragment_id
|
||||
ctx.current_fragment_delta_path = []
|
||||
|
||||
ctx.fragment_storage.set(fragment_id, wrapped_fragment)
|
||||
|
||||
if run_every:
|
||||
msg = ForwardMsg()
|
||||
msg.auto_rerun.interval = time_to_seconds(run_every)
|
||||
msg.auto_rerun.fragment_id = fragment_id
|
||||
ctx.enqueue(msg)
|
||||
|
||||
# Immediate execute the wrapped fragment since we are in a full app run
|
||||
return wrapped_fragment()
|
||||
|
||||
with contextlib.suppress(AttributeError):
|
||||
# Make this a well-behaved decorator by preserving important function
|
||||
# attributes.
|
||||
wrap.__dict__.update(non_optional_func.__dict__)
|
||||
wrap.__signature__ = inspect.signature(non_optional_func) # type: ignore
|
||||
|
||||
return wrap
|
||||
|
||||
|
||||
@overload
|
||||
def fragment(
|
||||
func: F,
|
||||
*,
|
||||
run_every: int | float | timedelta | str | None = None,
|
||||
) -> F: ...
|
||||
|
||||
|
||||
# Support being able to pass parameters to this decorator (that is, being able to write
|
||||
# `@fragment(run_every=5.0)`).
|
||||
@overload
|
||||
def fragment(
|
||||
func: None = None,
|
||||
*,
|
||||
run_every: int | float | timedelta | str | None = None,
|
||||
) -> Callable[[F], F]: ...
|
||||
|
||||
|
||||
@gather_metrics("fragment")
|
||||
def fragment(
|
||||
func: F | None = None,
|
||||
*,
|
||||
run_every: int | float | timedelta | str | None = None,
|
||||
) -> Callable[[F], F] | F:
|
||||
"""Decorator to turn a function into a fragment which can rerun independently\
|
||||
of the full app.
|
||||
|
||||
When a user interacts with an input widget created inside a fragment,
|
||||
Streamlit only reruns the fragment instead of the full app. If
|
||||
``run_every`` is set, Streamlit will also rerun the fragment at the
|
||||
specified interval while the session is active, even if the user is not
|
||||
interacting with your app.
|
||||
|
||||
To trigger an app rerun from inside a fragment, call ``st.rerun()``
|
||||
directly. To trigger a fragment rerun from within itself, call
|
||||
``st.rerun(scope="fragment")``. Any values from the fragment that need to
|
||||
be accessed from the wider app should generally be stored in Session State.
|
||||
|
||||
When Streamlit element commands are called directly in a fragment, the
|
||||
elements are cleared and redrawn on each fragment rerun, just like all
|
||||
elements are redrawn on each app rerun. The rest of the app is persisted
|
||||
during a fragment rerun. When a fragment renders elements into externally
|
||||
created containers, the elements will not be cleared with each fragment
|
||||
rerun. Instead, elements will accumulate in those containers with each
|
||||
fragment rerun, until the next app rerun.
|
||||
|
||||
Calling ``st.sidebar`` in a fragment is not supported. To write elements to
|
||||
the sidebar with a fragment, call your fragment function inside a
|
||||
``with st.sidebar`` context manager.
|
||||
|
||||
Fragment code can interact with Session State, imported modules, and
|
||||
other Streamlit elements created outside the fragment. Note that these
|
||||
interactions are additive across multiple fragment reruns. You are
|
||||
responsible for handling any side effects of that behavior.
|
||||
|
||||
.. warning::
|
||||
|
||||
- Fragments can only contain widgets in their main body. Fragments
|
||||
can't render widgets to externally created containers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func: callable
|
||||
The function to turn into a fragment.
|
||||
|
||||
run_every: int, float, timedelta, str, or None
|
||||
The time interval between automatic fragment reruns. This can be one of
|
||||
the following:
|
||||
|
||||
- ``None`` (default).
|
||||
- An ``int`` or ``float`` specifying the interval in seconds.
|
||||
- A string specifying the time in a format supported by `Pandas'
|
||||
Timedelta constructor <https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html>`_,
|
||||
e.g. ``"1d"``, ``"1.5 days"``, or ``"1h23s"``.
|
||||
- A ``timedelta`` object from `Python's built-in datetime library
|
||||
<https://docs.python.org/3/library/datetime.html#timedelta-objects>`_,
|
||||
e.g. ``timedelta(days=1)``.
|
||||
|
||||
If ``run_every`` is ``None``, the fragment will only rerun from
|
||||
user-triggered events.
|
||||
|
||||
Examples
|
||||
--------
|
||||
The following example demonstrates basic usage of
|
||||
``@st.fragment``. As an analogy, "inflating balloons" is a slow process that happens
|
||||
outside of the fragment. "Releasing balloons" is a quick process that happens inside
|
||||
of the fragment.
|
||||
|
||||
>>> import streamlit as st
|
||||
>>> import time
|
||||
>>>
|
||||
>>> @st.fragment
|
||||
>>> def release_the_balloons():
|
||||
>>> st.button("Release the balloons", help="Fragment rerun")
|
||||
>>> st.balloons()
|
||||
>>>
|
||||
>>> with st.spinner("Inflating balloons..."):
|
||||
>>> time.sleep(5)
|
||||
>>> release_the_balloons()
|
||||
>>> st.button("Inflate more balloons", help="Full rerun")
|
||||
|
||||
.. output::
|
||||
https://doc-fragment-balloons.streamlit.app/
|
||||
height: 220px
|
||||
|
||||
This next example demonstrates how elements both inside and outside of a
|
||||
fragement update with each app or fragment rerun. In this app, clicking
|
||||
"Rerun full app" will increment both counters and update all values
|
||||
displayed in the app. In contrast, clicking "Rerun fragment" will only
|
||||
increment the counter within the fragment. In this case, the ``st.write``
|
||||
command inside the fragment will update the app's frontend, but the two
|
||||
``st.write`` commands outside the fragment will not update the frontend.
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> if "app_runs" not in st.session_state:
|
||||
>>> st.session_state.app_runs = 0
|
||||
>>> st.session_state.fragment_runs = 0
|
||||
>>>
|
||||
>>> @st.fragment
|
||||
>>> def my_fragment():
|
||||
>>> st.session_state.fragment_runs += 1
|
||||
>>> st.button("Rerun fragment")
|
||||
>>> st.write(f"Fragment says it ran {st.session_state.fragment_runs} times.")
|
||||
>>>
|
||||
>>> st.session_state.app_runs += 1
|
||||
>>> my_fragment()
|
||||
>>> st.button("Rerun full app")
|
||||
>>> st.write(f"Full app says it ran {st.session_state.app_runs} times.")
|
||||
>>> st.write(f"Full app sees that fragment ran {st.session_state.fragment_runs} times.")
|
||||
|
||||
.. output::
|
||||
https://doc-fragment.streamlit.app/
|
||||
height: 400px
|
||||
|
||||
You can also trigger an app rerun from inside a fragment by calling
|
||||
``st.rerun``.
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> if "clicks" not in st.session_state:
|
||||
>>> st.session_state.clicks = 0
|
||||
>>>
|
||||
>>> @st.fragment
|
||||
>>> def count_to_five():
|
||||
>>> if st.button("Plus one!"):
|
||||
>>> st.session_state.clicks += 1
|
||||
>>> if st.session_state.clicks % 5 == 0:
|
||||
>>> st.rerun()
|
||||
>>> return
|
||||
>>>
|
||||
>>> count_to_five()
|
||||
>>> st.header(f"Multiples of five clicks: {st.session_state.clicks // 5}")
|
||||
>>>
|
||||
>>> if st.button("Check click count"):
|
||||
>>> st.toast(f"## Total clicks: {st.session_state.clicks}")
|
||||
|
||||
.. output::
|
||||
https://doc-fragment-rerun.streamlit.app/
|
||||
height: 400px
|
||||
|
||||
"""
|
||||
return _fragment(func, run_every=run_every)
|
||||
|
||||
|
||||
@overload
|
||||
def experimental_fragment(
|
||||
func: F,
|
||||
*,
|
||||
run_every: int | float | timedelta | str | None = None,
|
||||
) -> F: ...
|
||||
|
||||
|
||||
# Support being able to pass parameters to this decorator (that is, being able to write
|
||||
# `@fragment(run_every=5.0)`).
|
||||
@overload
|
||||
def experimental_fragment(
|
||||
func: None = None,
|
||||
*,
|
||||
run_every: int | float | timedelta | str | None = None,
|
||||
) -> Callable[[F], F]: ...
|
||||
|
||||
|
||||
@gather_metrics("experimental_fragment")
|
||||
def experimental_fragment(
|
||||
func: F | None = None,
|
||||
*,
|
||||
run_every: int | float | timedelta | str | None = None,
|
||||
) -> Callable[[F], F] | F:
|
||||
"""Deprecated alias for @st.fragment. See the docstring for the decorator's new name."""
|
||||
return _fragment(func, run_every=run_every, should_show_deprecation_warning=True)
|
||||
@@ -0,0 +1,234 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Provides global MediaFileManager object as `media_file_manager`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import threading
|
||||
from typing import Final
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorage
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_session_id() -> str:
|
||||
"""Get the active AppSession's session_id."""
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import (
|
||||
get_script_run_ctx,
|
||||
)
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
# This is only None when running "python myscript.py" rather than
|
||||
# "streamlit run myscript.py". In which case the session ID doesn't
|
||||
# matter and can just be a constant, as there's only ever "session".
|
||||
return "dontcare"
|
||||
else:
|
||||
return ctx.session_id
|
||||
|
||||
|
||||
class MediaFileMetadata:
|
||||
"""Metadata that the MediaFileManager needs for each file it manages."""
|
||||
|
||||
def __init__(self, kind: MediaFileKind = MediaFileKind.MEDIA):
|
||||
self._kind = kind
|
||||
self._is_marked_for_delete = False
|
||||
|
||||
@property
|
||||
def kind(self) -> MediaFileKind:
|
||||
return self._kind
|
||||
|
||||
@property
|
||||
def is_marked_for_delete(self) -> bool:
|
||||
return self._is_marked_for_delete
|
||||
|
||||
def mark_for_delete(self) -> None:
|
||||
self._is_marked_for_delete = True
|
||||
|
||||
|
||||
class MediaFileManager:
|
||||
"""In-memory file manager for MediaFile objects.
|
||||
|
||||
This keeps track of:
|
||||
- Which files exist, and what their IDs are. This is important so we can
|
||||
serve files by ID -- that's the whole point of this class!
|
||||
- Which files are being used by which AppSession (by ID). This is
|
||||
important so we can remove files from memory when no more sessions need
|
||||
them.
|
||||
- The exact location in the app where each file is being used (i.e. the
|
||||
file's "coordinates"). This is is important so we can mark a file as "not
|
||||
being used by a certain session" if it gets replaced by another file at
|
||||
the same coordinates. For example, when doing an animation where the same
|
||||
image is constantly replace with new frames. (This doesn't solve the case
|
||||
where the file's coordinates keep changing for some reason, though! e.g.
|
||||
if new elements keep being prepended to the app. Unlikely to happen, but
|
||||
we should address it at some point.)
|
||||
"""
|
||||
|
||||
def __init__(self, storage: MediaFileStorage):
|
||||
self._storage = storage
|
||||
|
||||
# Dict of [file_id -> MediaFileMetadata]
|
||||
self._file_metadata: dict[str, MediaFileMetadata] = {}
|
||||
|
||||
# Dict[session ID][coordinates] -> file_id.
|
||||
self._files_by_session_and_coord: dict[str, dict[str, str]] = (
|
||||
collections.defaultdict(dict)
|
||||
)
|
||||
|
||||
# MediaFileManager is used from multiple threads, so all operations
|
||||
# need to be protected with a Lock. (This is not an RLock, which
|
||||
# means taking it multiple times from the same thread will deadlock.)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _get_inactive_file_ids(self) -> set[str]:
|
||||
"""Compute the set of files that are stored in the manager, but are
|
||||
not referenced by any active session. These are files that can be
|
||||
safely deleted.
|
||||
|
||||
Thread safety: callers must hold `self._lock`.
|
||||
"""
|
||||
# Get the set of all our file IDs.
|
||||
file_ids = set(self._file_metadata.keys())
|
||||
|
||||
# Subtract all IDs that are in use by each session
|
||||
for session_file_ids_by_coord in self._files_by_session_and_coord.values():
|
||||
file_ids.difference_update(session_file_ids_by_coord.values())
|
||||
|
||||
return file_ids
|
||||
|
||||
def remove_orphaned_files(self) -> None:
|
||||
"""Remove all files that are no longer referenced by any active session.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
_LOGGER.debug("Removing orphaned files...")
|
||||
|
||||
with self._lock:
|
||||
for file_id in self._get_inactive_file_ids():
|
||||
file = self._file_metadata[file_id]
|
||||
if file.kind == MediaFileKind.MEDIA:
|
||||
self._delete_file(file_id)
|
||||
elif file.kind == MediaFileKind.DOWNLOADABLE:
|
||||
if file.is_marked_for_delete:
|
||||
self._delete_file(file_id)
|
||||
else:
|
||||
file.mark_for_delete()
|
||||
|
||||
def _delete_file(self, file_id: str) -> None:
|
||||
"""Delete the given file from storage, and remove its metadata from
|
||||
self._files_by_id.
|
||||
|
||||
Thread safety: callers must hold `self._lock`.
|
||||
"""
|
||||
_LOGGER.debug("Deleting File: %s", file_id)
|
||||
self._storage.delete_file(file_id)
|
||||
del self._file_metadata[file_id]
|
||||
|
||||
def clear_session_refs(self, session_id: str | None = None) -> None:
|
||||
"""Remove the given session's file references.
|
||||
|
||||
(This does not remove any files from the manager - you must call
|
||||
`remove_orphaned_files` for that.)
|
||||
|
||||
Should be called whenever ScriptRunner completes and when a session ends.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
if session_id is None:
|
||||
session_id = _get_session_id()
|
||||
|
||||
_LOGGER.debug("Disconnecting files for session with ID %s", session_id)
|
||||
|
||||
with self._lock:
|
||||
if session_id in self._files_by_session_and_coord:
|
||||
del self._files_by_session_and_coord[session_id]
|
||||
|
||||
_LOGGER.debug(
|
||||
"Sessions still active: %r", self._files_by_session_and_coord.keys()
|
||||
)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Files: %s; Sessions with files: %s",
|
||||
len(self._file_metadata),
|
||||
len(self._files_by_session_and_coord),
|
||||
)
|
||||
|
||||
def add(
|
||||
self,
|
||||
path_or_data: bytes | str,
|
||||
mimetype: str,
|
||||
coordinates: str,
|
||||
file_name: str | None = None,
|
||||
is_for_static_download: bool = False,
|
||||
) -> str:
|
||||
"""Add a new MediaFile with the given parameters and return its URL.
|
||||
|
||||
If an identical file already exists, return the existing URL
|
||||
and registers the current session as a user.
|
||||
|
||||
Safe to call from any thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path_or_data : bytes or str
|
||||
If bytes: the media file's raw data. If str: the name of a file
|
||||
to load from disk.
|
||||
mimetype : str
|
||||
The mime type for the file. E.g. "audio/mpeg".
|
||||
This string will be used in the "Content-Type" header when the file
|
||||
is served over HTTP.
|
||||
coordinates : str
|
||||
Unique string identifying an element's location.
|
||||
Prevents memory leak of "forgotten" file IDs when element media
|
||||
is being replaced-in-place (e.g. an st.image stream).
|
||||
coordinates should be of the form: "1.(3.-14).5"
|
||||
file_name : str or None
|
||||
Optional file_name. Used to set the filename in the response header.
|
||||
is_for_static_download: bool
|
||||
Indicate that data stored for downloading as a file,
|
||||
not as a media for rendering at page. [default: False]
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The url that the frontend can use to fetch the media.
|
||||
|
||||
Raises
|
||||
------
|
||||
If a filename is passed, any Exception raised when trying to read the
|
||||
file will be re-raised.
|
||||
"""
|
||||
|
||||
session_id = _get_session_id()
|
||||
|
||||
with self._lock:
|
||||
kind = (
|
||||
MediaFileKind.DOWNLOADABLE
|
||||
if is_for_static_download
|
||||
else MediaFileKind.MEDIA
|
||||
)
|
||||
file_id = self._storage.load_and_get_id(
|
||||
path_or_data, mimetype, kind, file_name
|
||||
)
|
||||
metadata = MediaFileMetadata(kind=kind)
|
||||
|
||||
self._file_metadata[file_id] = metadata
|
||||
self._files_by_session_and_coord[session_id][coordinates] = file_id
|
||||
|
||||
return self._storage.get_url(file_id)
|
||||
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class MediaFileKind(Enum):
|
||||
# st.image, st.video, st.audio files
|
||||
MEDIA = "media"
|
||||
|
||||
# st.download_button files
|
||||
DOWNLOADABLE = "downloadable"
|
||||
|
||||
|
||||
class MediaFileStorageError(Exception):
|
||||
"""Exception class for errors raised by MediaFileStorage.
|
||||
|
||||
When running in "development mode", the full text of these errors
|
||||
is displayed in the frontend, so errors should be human-readable
|
||||
(and actionable).
|
||||
|
||||
When running in "release mode", errors are redacted on the
|
||||
frontend; we instead show a generic "Something went wrong!" message.
|
||||
"""
|
||||
|
||||
|
||||
class MediaFileStorage(Protocol):
|
||||
@abstractmethod
|
||||
def load_and_get_id(
|
||||
self,
|
||||
path_or_data: str | bytes,
|
||||
mimetype: str,
|
||||
kind: MediaFileKind,
|
||||
filename: str | None = None,
|
||||
) -> str:
|
||||
"""Load the given file path or bytes into the manager and return
|
||||
an ID that uniquely identifies it.
|
||||
|
||||
It's an error to pass a URL to this function. (Media stored at
|
||||
external URLs can be served directly to the Streamlit frontend;
|
||||
there's no need to store this data in MediaFileStorage.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path_or_data
|
||||
A path to a file, or the file's raw data as bytes.
|
||||
|
||||
mimetype
|
||||
The media's mimetype. Used to set the Content-Type header when
|
||||
serving the media over HTTP.
|
||||
|
||||
kind
|
||||
The kind of file this is: either MEDIA, or DOWNLOADABLE.
|
||||
|
||||
filename : str or None
|
||||
Optional filename. Used to set the filename in the response header.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The unique ID of the media file.
|
||||
|
||||
Raises
|
||||
------
|
||||
MediaFileStorageError
|
||||
Raised if the media can't be loaded (for example, if a file
|
||||
path is invalid).
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, file_id: str) -> str:
|
||||
"""Return a URL for a file in the manager.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_id
|
||||
The file's ID, returned from load_media_and_get_id().
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
A URL that the frontend can load the file from. Because this
|
||||
URL may expire, it should not be cached!
|
||||
|
||||
Raises
|
||||
------
|
||||
MediaFileStorageError
|
||||
Raised if the manager doesn't contain an object with the given ID.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, file_id: str) -> None:
|
||||
"""Delete a file from the manager.
|
||||
|
||||
This should be called when a given file is no longer referenced
|
||||
by any connected client, so that the MediaFileStorage can free its
|
||||
resources.
|
||||
|
||||
Calling `delete_file` on a file_id that doesn't exist is allowed,
|
||||
and is a no-op. (This means that multiple `delete_file` calls with
|
||||
the same file_id is not an error.)
|
||||
|
||||
Note: implementations can choose to ignore `delete_file` calls -
|
||||
this function is a *suggestion*, not a *command*. Callers should
|
||||
not rely on file deletion happening immediately (or at all).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_id
|
||||
The file's ID, returned from load_media_and_get_id().
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
Raises
|
||||
------
|
||||
MediaFileStorageError
|
||||
Raised if file deletion fails for any reason. Note that these
|
||||
failures will generally not be shown on the frontend (file
|
||||
deletion usually occurs on session disconnect).
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,181 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MediaFileStorage implementation that stores files in memory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import os.path
|
||||
from typing import Final, NamedTuple
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.media_file_storage import (
|
||||
MediaFileKind,
|
||||
MediaFileStorage,
|
||||
MediaFileStorageError,
|
||||
)
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
# Mimetype -> filename extension map for the `get_extension_for_mimetype`
|
||||
# function. We use Python's `mimetypes.guess_extension` for most mimetypes,
|
||||
# but (as of Python 3.9) `mimetypes.guess_extension("audio/wav")` returns None,
|
||||
# so we handle it ourselves.
|
||||
PREFERRED_MIMETYPE_EXTENSION_MAP: Final = {
|
||||
"audio/wav": ".wav",
|
||||
"text/vtt": ".vtt",
|
||||
}
|
||||
|
||||
|
||||
def _calculate_file_id(data: bytes, mimetype: str, filename: str | None = None) -> str:
|
||||
"""Hash data, mimetype, and an optional filename to generate a stable file ID.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data
|
||||
Content of in-memory file in bytes. Other types will throw TypeError.
|
||||
mimetype
|
||||
Any string. Will be converted to bytes and used to compute a hash.
|
||||
filename
|
||||
Any string. Will be converted to bytes and used to compute a hash.
|
||||
"""
|
||||
filehash = hashlib.new("sha224", usedforsecurity=False)
|
||||
filehash.update(data)
|
||||
filehash.update(bytes(mimetype.encode()))
|
||||
|
||||
if filename is not None:
|
||||
filehash.update(bytes(filename.encode()))
|
||||
|
||||
return filehash.hexdigest()
|
||||
|
||||
|
||||
def get_extension_for_mimetype(mimetype: str) -> str:
|
||||
if mimetype in PREFERRED_MIMETYPE_EXTENSION_MAP:
|
||||
return PREFERRED_MIMETYPE_EXTENSION_MAP[mimetype]
|
||||
|
||||
extension = mimetypes.guess_extension(mimetype, strict=False)
|
||||
if extension is None:
|
||||
return ""
|
||||
|
||||
return extension
|
||||
|
||||
|
||||
class MemoryFile(NamedTuple):
|
||||
"""A MediaFile stored in memory."""
|
||||
|
||||
content: bytes
|
||||
mimetype: str
|
||||
kind: MediaFileKind
|
||||
filename: str | None
|
||||
|
||||
@property
|
||||
def content_size(self) -> int:
|
||||
return len(self.content)
|
||||
|
||||
|
||||
class MemoryMediaFileStorage(MediaFileStorage, CacheStatsProvider):
|
||||
def __init__(self, media_endpoint: str):
|
||||
"""Create a new MemoryMediaFileStorage instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
media_endpoint
|
||||
The name of the local endpoint that media is served from.
|
||||
This endpoint should start with a forward-slash (e.g. "/media").
|
||||
"""
|
||||
self._files_by_id: dict[str, MemoryFile] = {}
|
||||
self._media_endpoint = media_endpoint
|
||||
|
||||
def load_and_get_id(
|
||||
self,
|
||||
path_or_data: str | bytes,
|
||||
mimetype: str,
|
||||
kind: MediaFileKind,
|
||||
filename: str | None = None,
|
||||
) -> str:
|
||||
"""Add a file to the manager and return its ID."""
|
||||
file_data: bytes
|
||||
if isinstance(path_or_data, str):
|
||||
file_data = self._read_file(path_or_data)
|
||||
else:
|
||||
file_data = path_or_data
|
||||
|
||||
# Because our file_ids are stable, if we already have a file with the
|
||||
# given ID, we don't need to create a new one.
|
||||
file_id = _calculate_file_id(file_data, mimetype, filename)
|
||||
if file_id not in self._files_by_id:
|
||||
_LOGGER.debug("Adding media file %s", file_id)
|
||||
media_file = MemoryFile(
|
||||
content=file_data, mimetype=mimetype, kind=kind, filename=filename
|
||||
)
|
||||
self._files_by_id[file_id] = media_file
|
||||
|
||||
return file_id
|
||||
|
||||
def get_file(self, filename: str) -> MemoryFile:
|
||||
"""Return the MemoryFile with the given filename. Filenames are of the
|
||||
form "file_id.extension". (Note that this is *not* the optional
|
||||
user-specified filename for download files.).
|
||||
|
||||
Raises a MediaFileStorageError if no such file exists.
|
||||
"""
|
||||
file_id = os.path.splitext(filename)[0]
|
||||
try:
|
||||
return self._files_by_id[file_id]
|
||||
except KeyError as e:
|
||||
raise MediaFileStorageError(
|
||||
f"Bad filename '{filename}'. (No media file with id '{file_id}')"
|
||||
) from e
|
||||
|
||||
def get_url(self, file_id: str) -> str:
|
||||
"""Get a URL for a given media file. Raise a MediaFileStorageError if
|
||||
no such file exists.
|
||||
"""
|
||||
media_file = self.get_file(file_id)
|
||||
extension = get_extension_for_mimetype(media_file.mimetype)
|
||||
return f"{self._media_endpoint}/{file_id}{extension}"
|
||||
|
||||
def delete_file(self, file_id: str) -> None:
|
||||
"""Delete the file with the given ID."""
|
||||
# We swallow KeyErrors here - it's not an error to delete a file
|
||||
# that doesn't exist.
|
||||
with contextlib.suppress(KeyError):
|
||||
del self._files_by_id[file_id]
|
||||
|
||||
def _read_file(self, filename: str) -> bytes:
|
||||
"""Read a file into memory. Raise MediaFileStorageError if we can't."""
|
||||
try:
|
||||
with open(filename, "rb") as f:
|
||||
return f.read()
|
||||
except Exception as ex:
|
||||
raise MediaFileStorageError(f"Error opening '{filename}'") from ex
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
# We operate on a copy of our dict, to avoid race conditions
|
||||
# with other threads that may be manipulating the cache.
|
||||
files_by_id = self._files_by_id.copy()
|
||||
|
||||
stats: list[CacheStat] = [
|
||||
CacheStat(
|
||||
category_name="st_memory_media_file_storage",
|
||||
cache_name="",
|
||||
byte_length=len(file.content),
|
||||
)
|
||||
for _, file in files_by_id.items()
|
||||
]
|
||||
return group_stats(stats)
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from streamlit.runtime.session_manager import SessionInfo, SessionStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import MutableMapping
|
||||
|
||||
|
||||
class MemorySessionStorage(SessionStorage):
|
||||
"""A SessionStorage that stores sessions in memory.
|
||||
|
||||
At most maxsize sessions are stored with a TTL of ttl seconds. This class is really
|
||||
just a thin wrapper around cachetools.TTLCache that complies with the SessionStorage
|
||||
protocol.
|
||||
"""
|
||||
|
||||
# NOTE: The defaults for maxsize and ttl are chosen arbitrarily for now. These
|
||||
# numbers are reasonable as the main problems we're trying to solve at the moment are
|
||||
# caused by transient disconnects that are usually just short network blips. In the
|
||||
# future, we may want to increase both to support use cases such as saving state for
|
||||
# much longer periods of time. For example, we may want session state to persist if
|
||||
# a user closes their laptop lid and comes back to an app hours later.
|
||||
def __init__(
|
||||
self,
|
||||
maxsize: int = 128,
|
||||
ttl_seconds: int = 2 * 60, # 2 minutes
|
||||
) -> None:
|
||||
"""Instantiate a new MemorySessionStorage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
maxsize
|
||||
The maximum number of sessions we allow to be stored in this
|
||||
MemorySessionStorage. If an entry needs to be removed because we have
|
||||
exceeded this number, either
|
||||
- an expired entry is removed, or
|
||||
- the least recently used entry is removed (if no entries have expired).
|
||||
|
||||
ttl_seconds
|
||||
The time in seconds for an entry added to a MemorySessionStorage to live.
|
||||
After this amount of time has passed for a given entry, it becomes
|
||||
inaccessible and will be removed eventually.
|
||||
"""
|
||||
|
||||
self._cache: MutableMapping[str, SessionInfo] = TTLCache(
|
||||
maxsize=maxsize, ttl=ttl_seconds
|
||||
)
|
||||
|
||||
def get(self, session_id: str) -> SessionInfo | None:
|
||||
return self._cache.get(session_id, None)
|
||||
|
||||
def save(self, session_info: SessionInfo) -> None:
|
||||
self._cache[session_info.session.id] = session_info
|
||||
|
||||
def delete(self, session_id: str) -> None:
|
||||
del self._cache[session_id]
|
||||
|
||||
def list(self) -> list[SessionInfo]:
|
||||
return list(self._cache.values())
|
||||
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from streamlit import util
|
||||
from streamlit.runtime.stats import CacheStat, group_stats
|
||||
from streamlit.runtime.uploaded_file_manager import (
|
||||
UploadedFileManager,
|
||||
UploadedFileRec,
|
||||
UploadFileUrlInfo,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
class MemoryUploadedFileManager(UploadedFileManager):
|
||||
"""Holds files uploaded by users of the running Streamlit app.
|
||||
This class can be used safely from multiple threads simultaneously.
|
||||
"""
|
||||
|
||||
def __init__(self, upload_endpoint: str):
|
||||
self.file_storage: dict[str, dict[str, UploadedFileRec]] = defaultdict(dict)
|
||||
self.endpoint = upload_endpoint
|
||||
|
||||
def get_files(
|
||||
self, session_id: str, file_ids: Sequence[str]
|
||||
) -> list[UploadedFileRec]:
|
||||
"""Return a list of UploadedFileRec for a given sequence of file_ids.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The ID of the session that owns the files.
|
||||
file_ids
|
||||
The sequence of ids associated with files to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[UploadedFileRec]
|
||||
A list of URL UploadedFileRec instances, each instance contains information
|
||||
about uploaded file.
|
||||
"""
|
||||
session_storage = self.file_storage[session_id]
|
||||
file_recs = []
|
||||
|
||||
for file_id in file_ids:
|
||||
file_rec = session_storage.get(file_id, None)
|
||||
if file_rec is not None:
|
||||
file_recs.append(file_rec)
|
||||
|
||||
return file_recs
|
||||
|
||||
def remove_session_files(self, session_id: str) -> None:
|
||||
"""Remove all files associated with a given session."""
|
||||
self.file_storage.pop(session_id, None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def add_file(
|
||||
self,
|
||||
session_id: str,
|
||||
file: UploadedFileRec,
|
||||
) -> None:
|
||||
"""
|
||||
Safe to call from any thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The ID of the session that owns the file.
|
||||
file
|
||||
The file to add.
|
||||
"""
|
||||
|
||||
self.file_storage[session_id][file.file_id] = file
|
||||
|
||||
def remove_file(self, session_id, file_id):
|
||||
"""Remove file with given file_id associated with a given session."""
|
||||
session_storage = self.file_storage[session_id]
|
||||
session_storage.pop(file_id, None)
|
||||
|
||||
def get_upload_urls(
|
||||
self, session_id: str, file_names: Sequence[str]
|
||||
) -> list[UploadFileUrlInfo]:
|
||||
"""Return a list of UploadFileUrlInfo for a given sequence of file_names."""
|
||||
result = []
|
||||
for _ in file_names:
|
||||
file_id = str(uuid.uuid4())
|
||||
result.append(
|
||||
UploadFileUrlInfo(
|
||||
file_id=file_id,
|
||||
upload_url=f"{self.endpoint}/{session_id}/{file_id}",
|
||||
delete_url=f"{self.endpoint}/{session_id}/{file_id}",
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
"""Return the manager's CacheStats.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
# Flatten all files into a single list
|
||||
all_files: list[UploadedFileRec] = []
|
||||
# Make copy of self.file_storage for thread safety, to be sure
|
||||
# that main storage won't be changed form other thread
|
||||
file_storage_copy = self.file_storage.copy()
|
||||
|
||||
for session_storage in file_storage_copy.values():
|
||||
all_files.extend(session_storage.values())
|
||||
|
||||
stats: list[CacheStat] = [
|
||||
CacheStat(
|
||||
category_name="UploadedFileManager",
|
||||
cache_name="",
|
||||
byte_length=len(file.data),
|
||||
)
|
||||
for file in all_files
|
||||
]
|
||||
return group_stats(stats)
|
||||
@@ -0,0 +1,486 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Sized
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Final, TypeVar, cast, overload
|
||||
|
||||
from streamlit import config, util
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.proto.PageProfile_pb2 import Argument, Command
|
||||
from streamlit.runtime.scriptrunner_utils.exceptions import RerunException
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
# Limit the number of commands to keep the page profile message small
|
||||
_MAX_TRACKED_COMMANDS: Final = 200
|
||||
# Only track a maximum of 25 uses per unique command since some apps use
|
||||
# commands excessively (e.g. calling add_rows thousands of times in one rerun)
|
||||
# making the page profile useless.
|
||||
_MAX_TRACKED_PER_COMMAND: Final = 25
|
||||
|
||||
# A mapping to convert from the actual name to preferred/shorter representations
|
||||
_OBJECT_NAME_MAPPING: Final = {
|
||||
"streamlit.delta_generator.DeltaGenerator": "DG",
|
||||
"pandas.core.frame.DataFrame": "DataFrame",
|
||||
"plotly.graph_objs._figure.Figure": "PlotlyFigure",
|
||||
"bokeh.plotting.figure.Figure": "BokehFigure",
|
||||
"matplotlib.figure.Figure": "MatplotlibFigure",
|
||||
"pandas.io.formats.style.Styler": "PandasStyler",
|
||||
"pandas.core.indexes.base.Index": "PandasIndex",
|
||||
"pandas.core.series.Series": "PandasSeries",
|
||||
"streamlit.connections.snowpark_connection.SnowparkConnection": "SnowparkConnection",
|
||||
"streamlit.connections.sql_connection.SQLConnection": "SQLConnection",
|
||||
}
|
||||
|
||||
# A list of dependencies to check for attribution
|
||||
_ATTRIBUTIONS_TO_CHECK: Final = [
|
||||
# DB Clients:
|
||||
"pymysql",
|
||||
"MySQLdb",
|
||||
"mysql",
|
||||
"pymongo",
|
||||
"ibis",
|
||||
"boto3",
|
||||
"psycopg2",
|
||||
"psycopg3",
|
||||
"sqlalchemy",
|
||||
"elasticsearch",
|
||||
"pyodbc",
|
||||
"pymssql",
|
||||
"cassandra",
|
||||
"azure",
|
||||
"redis",
|
||||
"sqlite3",
|
||||
"neo4j",
|
||||
"duckdb",
|
||||
"opensearchpy",
|
||||
"supabase",
|
||||
# Dataframe Libraries:
|
||||
"polars",
|
||||
"dask",
|
||||
"vaex",
|
||||
"modin",
|
||||
"pyspark",
|
||||
"cudf",
|
||||
"xarray",
|
||||
"ray",
|
||||
"geopandas",
|
||||
"mars",
|
||||
"tables",
|
||||
"zarr",
|
||||
"datasets",
|
||||
# ML & LLM Tools:
|
||||
"mistralai",
|
||||
"openai",
|
||||
"langchain",
|
||||
"llama_index",
|
||||
"llama_cpp",
|
||||
"anthropic",
|
||||
"pyllamacpp",
|
||||
"cohere",
|
||||
"transformers",
|
||||
"nomic",
|
||||
"diffusers",
|
||||
"semantic_kernel",
|
||||
"replicate",
|
||||
"huggingface_hub",
|
||||
"wandb",
|
||||
"torch",
|
||||
"tensorflow",
|
||||
"trubrics",
|
||||
"comet_ml",
|
||||
"clarifai",
|
||||
"reka",
|
||||
"hegel",
|
||||
"fastchat",
|
||||
"assemblyai",
|
||||
"openllm",
|
||||
"embedchain",
|
||||
"haystack",
|
||||
"vllm",
|
||||
"alpa",
|
||||
"jinaai",
|
||||
"guidance",
|
||||
"litellm",
|
||||
"comet_llm",
|
||||
"instructor",
|
||||
"xgboost",
|
||||
"lightgbm",
|
||||
"catboost",
|
||||
"sklearn",
|
||||
# Workflow Tools:
|
||||
"prefect",
|
||||
"luigi",
|
||||
"airflow",
|
||||
"dagster",
|
||||
# Vector Stores:
|
||||
"pgvector",
|
||||
"faiss",
|
||||
"annoy",
|
||||
"pinecone",
|
||||
"chromadb",
|
||||
"weaviate",
|
||||
"qdrant_client",
|
||||
"pymilvus",
|
||||
"lancedb",
|
||||
# Others:
|
||||
"snowflake",
|
||||
"streamlit_extras",
|
||||
"streamlit_pydantic",
|
||||
"pydantic",
|
||||
"plost",
|
||||
"authlib",
|
||||
]
|
||||
|
||||
_ETC_MACHINE_ID_PATH = "/etc/machine-id"
|
||||
_DBUS_MACHINE_ID_PATH = "/var/lib/dbus/machine-id"
|
||||
|
||||
|
||||
def _get_machine_id_v3() -> str:
|
||||
"""Get the machine ID.
|
||||
|
||||
This is a unique identifier for a user for tracking metrics,
|
||||
that is broken in different ways in some Linux distros and Docker images.
|
||||
- at times just a hash of '', which means many machines map to the same ID
|
||||
- at times a hash of the same string, when running in a Docker container
|
||||
"""
|
||||
|
||||
machine_id = str(uuid.getnode())
|
||||
if os.path.isfile(_ETC_MACHINE_ID_PATH):
|
||||
with open(_ETC_MACHINE_ID_PATH) as f:
|
||||
machine_id = f.read()
|
||||
|
||||
elif os.path.isfile(_DBUS_MACHINE_ID_PATH):
|
||||
with open(_DBUS_MACHINE_ID_PATH) as f:
|
||||
machine_id = f.read()
|
||||
|
||||
return machine_id
|
||||
|
||||
|
||||
class Installation:
|
||||
_instance_lock = threading.Lock()
|
||||
_instance: Installation | None = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls) -> Installation:
|
||||
"""Returns the singleton Installation."""
|
||||
# We use a double-checked locking optimization to avoid the overhead
|
||||
# of acquiring the lock in the common case:
|
||||
# https://en.wikipedia.org/wiki/Double-checked_locking
|
||||
if cls._instance is None:
|
||||
with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = Installation()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
self.installation_id_v3 = str(
|
||||
uuid.uuid5(uuid.NAMESPACE_DNS, _get_machine_id_v3())
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
@property
|
||||
def installation_id(self):
|
||||
return self.installation_id_v3
|
||||
|
||||
|
||||
def _get_type_name(obj: object) -> str:
|
||||
"""Get a simplified name for the type of the given object."""
|
||||
with contextlib.suppress(Exception):
|
||||
obj_type = obj if inspect.isclass(obj) else type(obj)
|
||||
type_name = "unknown"
|
||||
if hasattr(obj_type, "__qualname__"):
|
||||
type_name = obj_type.__qualname__
|
||||
elif hasattr(obj_type, "__name__"):
|
||||
type_name = obj_type.__name__
|
||||
|
||||
if obj_type.__module__ != "builtins":
|
||||
# Add the full module path
|
||||
type_name = f"{obj_type.__module__}.{type_name}"
|
||||
|
||||
if type_name in _OBJECT_NAME_MAPPING:
|
||||
type_name = _OBJECT_NAME_MAPPING[type_name]
|
||||
return type_name
|
||||
return "failed"
|
||||
|
||||
|
||||
def _get_top_level_module(func: Callable[..., Any]) -> str:
|
||||
"""Get the top level module for the given function."""
|
||||
module = inspect.getmodule(func)
|
||||
if module is None or not module.__name__:
|
||||
return "unknown"
|
||||
return module.__name__.split(".")[0]
|
||||
|
||||
|
||||
def _get_arg_metadata(arg: object) -> str | None:
|
||||
"""Get metadata information related to the value of the given object."""
|
||||
with contextlib.suppress(Exception):
|
||||
if isinstance(arg, (bool)):
|
||||
return f"val:{arg}"
|
||||
|
||||
if isinstance(arg, Sized):
|
||||
return f"len:{len(arg)}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_command_telemetry(
|
||||
_command_func: Callable[..., Any], _command_name: str, *args, **kwargs
|
||||
) -> Command:
|
||||
"""Get telemetry information for the given callable and its arguments."""
|
||||
arg_keywords = inspect.getfullargspec(_command_func).args
|
||||
self_arg: Any | None = None
|
||||
arguments: list[Argument] = []
|
||||
is_method = inspect.ismethod(_command_func)
|
||||
name = _command_name
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
pos = i
|
||||
if is_method:
|
||||
# If func is a method, ignore the first argument (self)
|
||||
i = i + 1
|
||||
|
||||
keyword = arg_keywords[i] if len(arg_keywords) > i else f"{i}"
|
||||
if keyword == "self":
|
||||
self_arg = arg
|
||||
continue
|
||||
argument = Argument(k=keyword, t=_get_type_name(arg), p=pos)
|
||||
|
||||
arg_metadata = _get_arg_metadata(arg)
|
||||
if arg_metadata:
|
||||
argument.m = arg_metadata
|
||||
arguments.append(argument)
|
||||
for kwarg, kwarg_value in kwargs.items():
|
||||
argument = Argument(k=kwarg, t=_get_type_name(kwarg_value))
|
||||
|
||||
arg_metadata = _get_arg_metadata(kwarg_value)
|
||||
if arg_metadata:
|
||||
argument.m = arg_metadata
|
||||
arguments.append(argument)
|
||||
|
||||
top_level_module = _get_top_level_module(_command_func)
|
||||
if top_level_module != "streamlit":
|
||||
# If the gather_metrics decorator is used outside of streamlit library
|
||||
# we enforce a prefix to be added to the tracked command:
|
||||
name = f"external:{top_level_module}:{name}"
|
||||
|
||||
if (
|
||||
name == "create_instance"
|
||||
and self_arg
|
||||
and hasattr(self_arg, "name")
|
||||
and self_arg.name
|
||||
):
|
||||
name = f"component:{self_arg.name}"
|
||||
|
||||
return Command(name=name, args=arguments)
|
||||
|
||||
|
||||
def to_microseconds(seconds: float) -> int:
|
||||
"""Convert seconds into microseconds."""
|
||||
return int(seconds * 1_000_000)
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@overload
|
||||
def gather_metrics(
|
||||
name: str,
|
||||
func: F,
|
||||
) -> F: ...
|
||||
|
||||
|
||||
@overload
|
||||
def gather_metrics(
|
||||
name: str,
|
||||
func: None = None,
|
||||
) -> Callable[[F], F]: ...
|
||||
|
||||
|
||||
def gather_metrics(name: str, func: F | None = None) -> Callable[[F], F] | F:
|
||||
"""Function decorator to add telemetry tracking to commands.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
The function to track for telemetry.
|
||||
|
||||
name : str or None
|
||||
Overwrite the function name with a custom name that is used for telemetry tracking.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> @st.gather_metrics
|
||||
... def my_command(url):
|
||||
... return url
|
||||
|
||||
>>> @st.gather_metrics(name="custom_name")
|
||||
... def my_command(url):
|
||||
... return url
|
||||
"""
|
||||
|
||||
if not name:
|
||||
_LOGGER.warning("gather_metrics: name is empty")
|
||||
name = "undefined"
|
||||
|
||||
if func is None:
|
||||
# Support passing the params via function decorator
|
||||
def wrapper(f: F) -> F:
|
||||
return gather_metrics(
|
||||
name=name,
|
||||
func=f,
|
||||
)
|
||||
|
||||
return wrapper
|
||||
else:
|
||||
# To make mypy type narrow F | None -> F
|
||||
non_optional_func = func
|
||||
|
||||
@wraps(non_optional_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
from timeit import default_timer as timer
|
||||
|
||||
exec_start = timer()
|
||||
ctx = get_script_run_ctx(suppress_warning=True)
|
||||
|
||||
tracking_activated = (
|
||||
ctx is not None
|
||||
and ctx.gather_usage_stats
|
||||
and not ctx.command_tracking_deactivated
|
||||
and len(ctx.tracked_commands)
|
||||
< _MAX_TRACKED_COMMANDS # Prevent too much memory usage
|
||||
)
|
||||
|
||||
command_telemetry: Command | None = None
|
||||
# This flag is needed to make sure that only the command (the outermost command)
|
||||
# that deactivated tracking (via ctx.command_tracking_deactivated) is able to reset it
|
||||
# again. This is important to prevent nested commands from reactivating tracking.
|
||||
# At this point, we don't know yet if the command will deactivated tracking.
|
||||
has_set_command_tracking_deactivated = False
|
||||
|
||||
if ctx and tracking_activated:
|
||||
try:
|
||||
command_telemetry = _get_command_telemetry(
|
||||
non_optional_func, name, *args, **kwargs
|
||||
)
|
||||
|
||||
if (
|
||||
command_telemetry.name not in ctx.tracked_commands_counter
|
||||
or ctx.tracked_commands_counter[command_telemetry.name]
|
||||
< _MAX_TRACKED_PER_COMMAND
|
||||
):
|
||||
ctx.tracked_commands.append(command_telemetry)
|
||||
ctx.tracked_commands_counter.update([command_telemetry.name])
|
||||
# Deactivate tracking to prevent calls inside already tracked commands
|
||||
ctx.command_tracking_deactivated = True
|
||||
# The ctx.command_tracking_deactivated flag was set to True,
|
||||
# we also need to set has_set_command_tracking_deactivated to True
|
||||
# to make sure that this command is able to reset it again.
|
||||
has_set_command_tracking_deactivated = True
|
||||
except Exception as ex:
|
||||
# Always capture all exceptions since we want to make sure that
|
||||
# the telemetry never causes any issues.
|
||||
_LOGGER.debug("Failed to collect command telemetry", exc_info=ex)
|
||||
try:
|
||||
result = non_optional_func(*args, **kwargs)
|
||||
except RerunException as ex:
|
||||
# Duplicated from below, because static analysis tools get confused
|
||||
# by deferring the rethrow.
|
||||
if tracking_activated and command_telemetry:
|
||||
command_telemetry.time = to_microseconds(timer() - exec_start)
|
||||
raise ex
|
||||
finally:
|
||||
# Activate tracking again if command executes without any exceptions
|
||||
# we only want to do that if this command has set the
|
||||
# flag to deactivate tracking.
|
||||
if ctx and has_set_command_tracking_deactivated:
|
||||
ctx.command_tracking_deactivated = False
|
||||
|
||||
if tracking_activated and command_telemetry:
|
||||
# Set the execution time to the measured value
|
||||
command_telemetry.time = to_microseconds(timer() - exec_start)
|
||||
|
||||
return result
|
||||
|
||||
with contextlib.suppress(AttributeError):
|
||||
# Make this a well-behaved decorator by preserving important function
|
||||
# attributes.
|
||||
wrapped_func.__dict__.update(non_optional_func.__dict__)
|
||||
wrapped_func.__signature__ = inspect.signature(non_optional_func) # type: ignore
|
||||
return cast("F", wrapped_func)
|
||||
|
||||
|
||||
def create_page_profile_message(
|
||||
commands: list[Command],
|
||||
exec_time: int,
|
||||
prep_time: int,
|
||||
uncaught_exception: str | None = None,
|
||||
) -> ForwardMsg:
|
||||
"""Create and return the full PageProfile ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
page_profile = msg.page_profile
|
||||
|
||||
page_profile.commands.extend(commands)
|
||||
page_profile.exec_time = exec_time
|
||||
page_profile.prep_time = prep_time
|
||||
|
||||
page_profile.headless = config.get_option("server.headless")
|
||||
|
||||
# Collect all config options that have been manually set
|
||||
config_options: set[str] = set()
|
||||
if config._config_options:
|
||||
for option_name in config._config_options.keys():
|
||||
if not config.is_manually_set(option_name):
|
||||
# We only care about manually defined options
|
||||
continue
|
||||
|
||||
config_option = config._config_options[option_name]
|
||||
if config_option.is_default:
|
||||
option_name = f"{option_name}:default"
|
||||
config_options.add(option_name)
|
||||
|
||||
page_profile.config.extend(config_options)
|
||||
|
||||
# Check the predefined set of modules for attribution
|
||||
attributions: set[str] = {
|
||||
attribution
|
||||
for attribution in _ATTRIBUTIONS_TO_CHECK
|
||||
if attribution in sys.modules
|
||||
}
|
||||
|
||||
page_profile.os = str(sys.platform)
|
||||
page_profile.timezone = str(time.tzname)
|
||||
page_profile.attributions.extend(attributions)
|
||||
|
||||
if uncaught_exception:
|
||||
page_profile.uncaught_exception = uncaught_exception
|
||||
|
||||
if ctx := get_script_run_ctx():
|
||||
page_profile.is_fragment_run = bool(ctx.fragment_ids_this_run)
|
||||
|
||||
return msg
|
||||
@@ -0,0 +1,162 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from streamlit.util import calc_md5
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
|
||||
from streamlit.source_util import PageHash, PageInfo, PageName, ScriptPath
|
||||
|
||||
|
||||
class PagesManager:
|
||||
"""
|
||||
PagesManager is responsible for managing the set of pages that make up
|
||||
the entire application. At the start we assume the main script is the
|
||||
only page. As the script runs, the main script can call `st.navigation`
|
||||
to set the set of pages that make up the app.
|
||||
"""
|
||||
|
||||
uses_pages_directory: bool | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_script_path: ScriptPath,
|
||||
script_cache: ScriptCache | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._main_script_path = main_script_path
|
||||
self._main_script_hash: PageHash = calc_md5(main_script_path)
|
||||
self._script_cache = script_cache
|
||||
self._intended_page_script_hash: PageHash | None = None
|
||||
self._intended_page_name: PageName | None = None
|
||||
self._current_page_script_hash: PageHash = ""
|
||||
self._pages: dict[PageHash, PageInfo] | None = None
|
||||
# A relic of v1 of Multipage apps, we performed special handling
|
||||
# for apps with a pages directory. We will keep this flag around
|
||||
# for now to maintain the behavior for apps that were created with
|
||||
# the pages directory feature.
|
||||
#
|
||||
# NOTE: we will update the feature if the flag has not been set
|
||||
# this means that if users use v2 behavior, the flag will
|
||||
# always be set to False
|
||||
if PagesManager.uses_pages_directory is None:
|
||||
PagesManager.uses_pages_directory = Path(
|
||||
self.main_script_parent / "pages"
|
||||
).exists()
|
||||
|
||||
@property
|
||||
def main_script_path(self) -> ScriptPath:
|
||||
return self._main_script_path
|
||||
|
||||
@property
|
||||
def main_script_parent(self) -> Path:
|
||||
return Path(self._main_script_path).parent
|
||||
|
||||
@property
|
||||
def main_script_hash(self) -> PageHash:
|
||||
return self._main_script_hash
|
||||
|
||||
@property
|
||||
def current_page_script_hash(self) -> PageHash:
|
||||
return self._current_page_script_hash
|
||||
|
||||
@property
|
||||
def intended_page_name(self) -> PageName | None:
|
||||
return self._intended_page_name
|
||||
|
||||
@property
|
||||
def intended_page_script_hash(self) -> PageHash | None:
|
||||
return self._intended_page_script_hash
|
||||
|
||||
def set_current_page_script_hash(self, page_script_hash: PageHash) -> None:
|
||||
self._current_page_script_hash = page_script_hash
|
||||
|
||||
def get_main_page(self) -> PageInfo:
|
||||
return {
|
||||
"script_path": self._main_script_path,
|
||||
"page_script_hash": self._main_script_hash,
|
||||
}
|
||||
|
||||
def set_script_intent(
|
||||
self, page_script_hash: PageHash, page_name: PageName
|
||||
) -> None:
|
||||
self._intended_page_script_hash = page_script_hash
|
||||
self._intended_page_name = page_name
|
||||
|
||||
def get_initial_active_script(
|
||||
self, page_script_hash: PageHash, page_name: PageName
|
||||
) -> PageInfo | None:
|
||||
return {
|
||||
# We always run the main script in V2 as it's the common code
|
||||
"script_path": self.main_script_path,
|
||||
"page_script_hash": page_script_hash
|
||||
or self.main_script_hash, # Default Hash
|
||||
}
|
||||
|
||||
def get_pages(self) -> dict[PageHash, PageInfo]:
|
||||
# If pages are not set, provide the common page info where
|
||||
# - the main script path is the executing script to start
|
||||
# - the page script hash and name reflects the intended page requested
|
||||
return self._pages or {
|
||||
self.main_script_hash: {
|
||||
"page_script_hash": self.intended_page_script_hash or "",
|
||||
"page_name": self.intended_page_name or "",
|
||||
"icon": "",
|
||||
"script_path": self.main_script_path,
|
||||
}
|
||||
}
|
||||
|
||||
def set_pages(self, pages: dict[PageHash, PageInfo]) -> None:
|
||||
self._pages = pages
|
||||
|
||||
def get_page_script(self, fallback_page_hash: PageHash = "") -> PageInfo | None:
|
||||
if self._pages is None:
|
||||
return None
|
||||
|
||||
if self.intended_page_script_hash:
|
||||
# We assume that if initial page hash is specified, that a page should
|
||||
# exist, so we check out the page script hash or the default page hash
|
||||
# as a backup
|
||||
return self._pages.get(
|
||||
self.intended_page_script_hash,
|
||||
self._pages.get(fallback_page_hash, None),
|
||||
)
|
||||
elif self.intended_page_name:
|
||||
# If a user navigates directly to a non-main page of an app, the
|
||||
# the page name can identify the page script to run
|
||||
return next(
|
||||
filter(
|
||||
# There seems to be this weird bug with mypy where it
|
||||
# thinks that p can be None (which is impossible given the
|
||||
# types of pages), so we add `p and` at the beginning of
|
||||
# the predicate to circumvent this.
|
||||
lambda p: p and (p["url_pathname"] == self.intended_page_name),
|
||||
self._pages.values(),
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return self._pages.get(fallback_page_hash, None)
|
||||
|
||||
def get_page_script_byte_code(self, script_path: str) -> Any:
|
||||
if self._script_cache is None:
|
||||
# Returning an empty string for an empty script
|
||||
return ""
|
||||
|
||||
return self._script_cache.get_bytecode(script_path)
|
||||
792
myenv/lib/python3.11/site-packages/streamlit/runtime/runtime.py
Normal file
792
myenv/lib/python3.11/site-packages/streamlit/runtime/runtime.py
Normal file
@@ -0,0 +1,792 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Final, NamedTuple
|
||||
|
||||
from streamlit import config
|
||||
from streamlit.components.lib.local_component_registry import LocalComponentRegistry
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.app_session import AppSession
|
||||
from streamlit.runtime.caching import (
|
||||
get_data_cache_stats_provider,
|
||||
get_resource_cache_stats_provider,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.local_disk_cache_storage import (
|
||||
LocalDiskCacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.forward_msg_cache import (
|
||||
ForwardMsgCache,
|
||||
create_reference_msg,
|
||||
populate_hash_if_needed,
|
||||
)
|
||||
from streamlit.runtime.media_file_manager import MediaFileManager
|
||||
from streamlit.runtime.memory_session_storage import MemorySessionStorage
|
||||
from streamlit.runtime.runtime_util import is_cacheable_msg
|
||||
from streamlit.runtime.script_data import ScriptData
|
||||
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
|
||||
from streamlit.runtime.session_manager import (
|
||||
ActiveSessionInfo,
|
||||
SessionClient,
|
||||
SessionClientDisconnectedError,
|
||||
SessionManager,
|
||||
SessionStorage,
|
||||
)
|
||||
from streamlit.runtime.state import (
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
|
||||
SessionStateStatProvider,
|
||||
)
|
||||
from streamlit.runtime.stats import StatsManager
|
||||
from streamlit.runtime.websocket_session_manager import WebsocketSessionManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
|
||||
from streamlit.components.types.base_component_registry import BaseComponentRegistry
|
||||
from streamlit.proto.BackMsg_pb2 import BackMsg
|
||||
from streamlit.runtime.caching.storage import CacheStorageManager
|
||||
from streamlit.runtime.media_file_storage import MediaFileStorage
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
|
||||
# Wait for the script run result for 60s and if no result is available give up
|
||||
SCRIPT_RUN_CHECK_TIMEOUT: Final = 60
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
class RuntimeStoppedError(Exception):
|
||||
"""Raised by operations on a Runtime instance that is stopped."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RuntimeConfig:
|
||||
"""Config options for StreamlitRuntime."""
|
||||
|
||||
# The filesystem path of the Streamlit script to run.
|
||||
script_path: str
|
||||
|
||||
# DEPRECATED: We need to keep this field around for compatibility reasons, but we no
|
||||
# longer use this anywhere.
|
||||
command_line: str | None
|
||||
|
||||
# The storage backend for Streamlit's MediaFileManager.
|
||||
media_file_storage: MediaFileStorage
|
||||
|
||||
# The upload file manager
|
||||
uploaded_file_manager: UploadedFileManager
|
||||
|
||||
# The cache storage backend for Streamlit's st.cache_data.
|
||||
cache_storage_manager: CacheStorageManager = field(
|
||||
default_factory=LocalDiskCacheStorageManager
|
||||
)
|
||||
|
||||
# The ComponentRegistry instance to use.
|
||||
component_registry: BaseComponentRegistry = field(
|
||||
default_factory=LocalComponentRegistry
|
||||
)
|
||||
|
||||
# The SessionManager class to be used.
|
||||
session_manager_class: type[SessionManager] = WebsocketSessionManager
|
||||
|
||||
# The SessionStorage instance for the SessionManager to use.
|
||||
session_storage: SessionStorage = field(default_factory=MemorySessionStorage)
|
||||
|
||||
# True if the command used to start Streamlit was `streamlit hello`.
|
||||
is_hello: bool = False
|
||||
|
||||
# TODO(vdonato): Eventually add a new fragment_storage_class field enabling the code
|
||||
# creating a new Streamlit Runtime to configure the FragmentStorage instances
|
||||
# created by each new AppSession. We choose not to do this for now to avoid adding
|
||||
# additional complexity to RuntimeConfig/SessionManager/etc when it's unlikely
|
||||
# we'll have a custom implementation of this class anytime soon.
|
||||
|
||||
|
||||
class RuntimeState(Enum):
|
||||
INITIAL = "INITIAL"
|
||||
NO_SESSIONS_CONNECTED = "NO_SESSIONS_CONNECTED"
|
||||
ONE_OR_MORE_SESSIONS_CONNECTED = "ONE_OR_MORE_SESSIONS_CONNECTED"
|
||||
STOPPING = "STOPPING"
|
||||
STOPPED = "STOPPED"
|
||||
|
||||
|
||||
class AsyncObjects(NamedTuple):
|
||||
"""Container for all asyncio objects that Runtime manages.
|
||||
These cannot be initialized until the Runtime's eventloop is assigned.
|
||||
"""
|
||||
|
||||
# The eventloop that Runtime is running on.
|
||||
eventloop: asyncio.AbstractEventLoop
|
||||
|
||||
# Set after Runtime.stop() is called. Never cleared.
|
||||
must_stop: asyncio.Event
|
||||
|
||||
# Set when a client connects; cleared when we have no connected clients.
|
||||
has_connection: asyncio.Event
|
||||
|
||||
# Set after a ForwardMsg is enqueued; cleared when we flush ForwardMsgs.
|
||||
need_send_data: asyncio.Event
|
||||
|
||||
# Completed when the Runtime has started.
|
||||
started: asyncio.Future[None]
|
||||
|
||||
# Completed when the Runtime has stopped.
|
||||
stopped: asyncio.Future[None]
|
||||
|
||||
|
||||
class Runtime:
|
||||
_instance: Runtime | None = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls) -> Runtime:
|
||||
"""Return the singleton Runtime instance. Raise an Error if the
|
||||
Runtime hasn't been created yet.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("Runtime hasn't been created!")
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def exists(cls) -> bool:
|
||||
"""True if the singleton Runtime instance has been created.
|
||||
|
||||
When a Streamlit app is running in "raw mode" - that is, when the
|
||||
app is run via `python app.py` instead of `streamlit run app.py` -
|
||||
the Runtime will not exist, and various Streamlit functions need
|
||||
to adapt.
|
||||
"""
|
||||
return cls._instance is not None
|
||||
|
||||
def __init__(self, config: RuntimeConfig):
|
||||
"""Create a Runtime instance. It won't be started yet.
|
||||
|
||||
Runtime is *not* thread-safe. Its public methods are generally
|
||||
safe to call only on the same thread that its event loop runs on.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config
|
||||
Config options.
|
||||
"""
|
||||
if Runtime._instance is not None:
|
||||
raise RuntimeError("Runtime instance already exists!")
|
||||
Runtime._instance = self
|
||||
|
||||
# Will be created when we start.
|
||||
self._async_objs: AsyncObjects | None = None
|
||||
|
||||
# The task that runs our main loop. We need to save a reference
|
||||
# to it so that it doesn't get garbage collected while running.
|
||||
self._loop_coroutine_task: asyncio.Task[None] | None = None
|
||||
|
||||
self._main_script_path = config.script_path
|
||||
self._is_hello = config.is_hello
|
||||
|
||||
self._state = RuntimeState.INITIAL
|
||||
|
||||
# Initialize managers
|
||||
self._component_registry = config.component_registry
|
||||
self._message_cache = ForwardMsgCache()
|
||||
self._uploaded_file_mgr = config.uploaded_file_manager
|
||||
self._media_file_mgr = MediaFileManager(storage=config.media_file_storage)
|
||||
self._cache_storage_manager = config.cache_storage_manager
|
||||
self._script_cache = ScriptCache()
|
||||
|
||||
self._session_mgr = config.session_manager_class(
|
||||
session_storage=config.session_storage,
|
||||
uploaded_file_manager=self._uploaded_file_mgr,
|
||||
script_cache=self._script_cache,
|
||||
message_enqueued_callback=self._enqueued_some_message,
|
||||
)
|
||||
|
||||
self._stats_mgr = StatsManager()
|
||||
self._stats_mgr.register_provider(get_data_cache_stats_provider())
|
||||
self._stats_mgr.register_provider(get_resource_cache_stats_provider())
|
||||
self._stats_mgr.register_provider(self._message_cache)
|
||||
self._stats_mgr.register_provider(self._uploaded_file_mgr)
|
||||
self._stats_mgr.register_provider(SessionStateStatProvider(self._session_mgr))
|
||||
|
||||
@property
|
||||
def state(self) -> RuntimeState:
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def component_registry(self) -> BaseComponentRegistry:
|
||||
return self._component_registry
|
||||
|
||||
@property
|
||||
def message_cache(self) -> ForwardMsgCache:
|
||||
return self._message_cache
|
||||
|
||||
@property
|
||||
def uploaded_file_mgr(self) -> UploadedFileManager:
|
||||
return self._uploaded_file_mgr
|
||||
|
||||
@property
|
||||
def cache_storage_manager(self) -> CacheStorageManager:
|
||||
return self._cache_storage_manager
|
||||
|
||||
@property
|
||||
def media_file_mgr(self) -> MediaFileManager:
|
||||
return self._media_file_mgr
|
||||
|
||||
@property
|
||||
def stats_mgr(self) -> StatsManager:
|
||||
return self._stats_mgr
|
||||
|
||||
@property
|
||||
def stopped(self) -> Awaitable[None]:
|
||||
"""A Future that completes when the Runtime's run loop has exited."""
|
||||
return self._get_async_objs().stopped
|
||||
|
||||
# NOTE: A few Runtime methods listed as threadsafe (get_client and
|
||||
# is_active_session) currently rely on the implementation detail that
|
||||
# WebsocketSessionManager's get_active_session_info and is_active_session methods
|
||||
# happen to be threadsafe. This may change with future SessionManager implementations,
|
||||
# at which point we'll need to formalize our thread safety rules for each
|
||||
# SessionManager method.
|
||||
def get_client(self, session_id: str) -> SessionClient | None:
|
||||
"""Get the SessionClient for the given session_id, or None
|
||||
if no such session exists.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
session_info = self._session_mgr.get_active_session_info(session_id)
|
||||
if session_info is None:
|
||||
return None
|
||||
return session_info.client
|
||||
|
||||
def clear_user_info_for_session(self, session_id: str) -> None:
|
||||
"""Clear the user_info for the given session_id.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
session_info = self._session_mgr.get_session_info(session_id)
|
||||
if session_info is not None:
|
||||
session_info.session.clear_user_info()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the runtime. This must be called only once, before
|
||||
any other functions are called.
|
||||
|
||||
When this coroutine returns, Streamlit is ready to accept new sessions.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
|
||||
# Create our AsyncObjects. We need to have a running eventloop to
|
||||
# instantiate our various synchronization primitives.
|
||||
async_objs = AsyncObjects(
|
||||
eventloop=asyncio.get_running_loop(),
|
||||
must_stop=asyncio.Event(),
|
||||
has_connection=asyncio.Event(),
|
||||
need_send_data=asyncio.Event(),
|
||||
started=asyncio.Future(),
|
||||
stopped=asyncio.Future(),
|
||||
)
|
||||
self._async_objs = async_objs
|
||||
|
||||
self._loop_coroutine_task = asyncio.create_task(
|
||||
self._loop_coroutine(), name="Runtime.loop_coroutine"
|
||||
)
|
||||
|
||||
await async_objs.started
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Request that Streamlit close all sessions and stop running.
|
||||
Note that Streamlit won't stop running immediately.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called from any thread.
|
||||
"""
|
||||
|
||||
async_objs = self._get_async_objs()
|
||||
|
||||
def stop_on_eventloop():
|
||||
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
|
||||
return
|
||||
|
||||
_LOGGER.debug("Runtime stopping...")
|
||||
self._set_state(RuntimeState.STOPPING)
|
||||
async_objs.must_stop.set()
|
||||
|
||||
async_objs.eventloop.call_soon_threadsafe(stop_on_eventloop)
|
||||
|
||||
def is_active_session(self, session_id: str) -> bool:
|
||||
"""True if the session_id belongs to an active session.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
return self._session_mgr.is_active_session(session_id)
|
||||
|
||||
def connect_session(
|
||||
self,
|
||||
client: SessionClient,
|
||||
user_info: dict[str, str | bool | None],
|
||||
existing_session_id: str | None = None,
|
||||
session_id_override: str | None = None,
|
||||
) -> str:
|
||||
"""Create a new session (or connect to an existing one) and return its unique ID.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client
|
||||
A concrete SessionClient implementation for communicating with
|
||||
the session's client.
|
||||
user_info
|
||||
A dict that contains information about the session's user. For now,
|
||||
it only (optionally) contains the user's email address.
|
||||
|
||||
{
|
||||
"email": "example@example.com"
|
||||
}
|
||||
existing_session_id
|
||||
The ID of an existing session to reconnect to. If one is not provided, a new
|
||||
session is created. Note that whether the Runtime's SessionManager supports
|
||||
reconnecting to an existing session depends on the SessionManager that this
|
||||
runtime is configured with.
|
||||
session_id_override
|
||||
The ID to assign to a new session being created with this method. Setting
|
||||
this can be useful when the service that a Streamlit Runtime is running in
|
||||
wants to tie the lifecycle of a Streamlit session to some other session-like
|
||||
object that it manages. Only one of existing_session_id and
|
||||
session_id_override should be set.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The session's unique string ID.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
assert not (existing_session_id and session_id_override), (
|
||||
"Only one of existing_session_id and session_id_override should be set!"
|
||||
)
|
||||
|
||||
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
|
||||
raise RuntimeStoppedError(f"Can't connect_session (state={self._state})")
|
||||
|
||||
session_id = self._session_mgr.connect_session(
|
||||
client=client,
|
||||
script_data=ScriptData(self._main_script_path, self._is_hello),
|
||||
user_info=user_info,
|
||||
existing_session_id=existing_session_id,
|
||||
session_id_override=session_id_override,
|
||||
)
|
||||
self._set_state(RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED)
|
||||
self._get_async_objs().has_connection.set()
|
||||
|
||||
return session_id
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
client: SessionClient,
|
||||
user_info: dict[str, str | bool | None],
|
||||
existing_session_id: str | None = None,
|
||||
session_id_override: str | None = None,
|
||||
) -> str:
|
||||
"""Create a new session (or connect to an existing one) and return its unique ID.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This method is simply an alias for connect_session added for backwards
|
||||
compatibility.
|
||||
"""
|
||||
_LOGGER.warning("create_session is deprecated! Use connect_session instead.")
|
||||
return self.connect_session(
|
||||
client=client,
|
||||
user_info=user_info,
|
||||
existing_session_id=existing_session_id,
|
||||
session_id_override=session_id_override,
|
||||
)
|
||||
|
||||
def close_session(self, session_id: str) -> None:
|
||||
"""Close and completely shut down a session.
|
||||
|
||||
This differs from disconnect_session in that it always completely shuts down a
|
||||
session, permanently losing any associated state (session state, uploaded files,
|
||||
etc.).
|
||||
|
||||
This function may be called multiple times for the same session,
|
||||
which is not an error. (Subsequent calls just no-op.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
session_info = self._session_mgr.get_session_info(session_id)
|
||||
if session_info:
|
||||
self._message_cache.remove_refs_for_session(session_info.session)
|
||||
self._session_mgr.close_session(session_id)
|
||||
self._on_session_disconnected()
|
||||
|
||||
def disconnect_session(self, session_id: str) -> None:
|
||||
"""Disconnect a session. It will stop producing ForwardMsgs.
|
||||
|
||||
Differs from close_session because disconnected sessions can be reconnected to
|
||||
for a brief window (depending on the SessionManager/SessionStorage
|
||||
implementations used by the runtime).
|
||||
|
||||
This function may be called multiple times for the same session,
|
||||
which is not an error. (Subsequent calls just no-op.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
session_info = self._session_mgr.get_active_session_info(session_id)
|
||||
if session_info:
|
||||
# NOTE: Ideally, we'd like to keep ForwardMsgCache refs for a session around
|
||||
# when a session is disconnected (and defer their cleanup until the session
|
||||
# is garbage collected), but this would be difficult to do as the
|
||||
# ForwardMsgCache is not thread safe, and we have no guarantee that the
|
||||
# garbage collector will only run on the eventloop thread. Because of this,
|
||||
# we clean up refs now and accept the risk that we're deleting cache entries
|
||||
# that will be useful once the browser tab reconnects.
|
||||
self._message_cache.remove_refs_for_session(session_info.session)
|
||||
self._session_mgr.disconnect_session(session_id)
|
||||
self._on_session_disconnected()
|
||||
|
||||
def handle_backmsg(self, session_id: str, msg: BackMsg) -> None:
|
||||
"""Send a BackMsg to an active session.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
msg
|
||||
The BackMsg to deliver to the session.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
|
||||
raise RuntimeStoppedError(f"Can't handle_backmsg (state={self._state})")
|
||||
|
||||
session_info = self._session_mgr.get_active_session_info(session_id)
|
||||
if session_info is None:
|
||||
_LOGGER.debug(
|
||||
"Discarding BackMsg for disconnected session (id=%s)", session_id
|
||||
)
|
||||
return
|
||||
|
||||
session_info.session.handle_backmsg(msg)
|
||||
|
||||
def handle_backmsg_deserialization_exception(
|
||||
self, session_id: str, exc: BaseException
|
||||
) -> None:
|
||||
"""Handle an Exception raised during deserialization of a BackMsg.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
exc
|
||||
The Exception.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
|
||||
raise RuntimeStoppedError(
|
||||
f"Can't handle_backmsg_deserialization_exception (state={self._state})"
|
||||
)
|
||||
|
||||
session_info = self._session_mgr.get_active_session_info(session_id)
|
||||
if session_info is None:
|
||||
_LOGGER.debug(
|
||||
"Discarding BackMsg Exception for disconnected session (id=%s)",
|
||||
session_id,
|
||||
)
|
||||
return
|
||||
|
||||
session_info.session.handle_backmsg_exception(exc)
|
||||
|
||||
@property
|
||||
async def is_ready_for_browser_connection(self) -> tuple[bool, str]:
|
||||
if self._state not in (
|
||||
RuntimeState.INITIAL,
|
||||
RuntimeState.STOPPING,
|
||||
RuntimeState.STOPPED,
|
||||
):
|
||||
return True, "ok"
|
||||
|
||||
return False, "unavailable"
|
||||
|
||||
async def does_script_run_without_error(self) -> tuple[bool, str]:
|
||||
"""Load and execute the app's script to verify it runs without an error.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(True, "ok") if the script completes without error, or (False, err_msg)
|
||||
if the script raises an exception.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
# NOTE: We create an AppSession directly here instead of using the
|
||||
# SessionManager intentionally. This isn't a "real" session and is only being
|
||||
# used to test that the script runs without error.
|
||||
session = AppSession(
|
||||
script_data=ScriptData(self._main_script_path, self._is_hello),
|
||||
uploaded_file_manager=self._uploaded_file_mgr,
|
||||
script_cache=self._script_cache,
|
||||
message_enqueued_callback=self._enqueued_some_message,
|
||||
user_info={"email": "test@example.com"},
|
||||
)
|
||||
|
||||
try:
|
||||
session.request_rerun(None)
|
||||
|
||||
now = time.perf_counter()
|
||||
while ( # noqa: ASYNC110
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state
|
||||
and (time.perf_counter() - now) < SCRIPT_RUN_CHECK_TIMEOUT
|
||||
):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state:
|
||||
return False, "timeout"
|
||||
|
||||
ok = session.session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY]
|
||||
msg = "ok" if ok else "error"
|
||||
|
||||
return ok, msg
|
||||
finally:
|
||||
session.shutdown()
|
||||
|
||||
def _set_state(self, new_state: RuntimeState) -> None:
|
||||
_LOGGER.debug("Runtime state: %s -> %s", self._state, new_state)
|
||||
self._state = new_state
|
||||
|
||||
async def _loop_coroutine(self) -> None:
|
||||
"""The main Runtime loop.
|
||||
|
||||
This function won't exit until `stop` is called.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
|
||||
async_objs = self._get_async_objs()
|
||||
|
||||
try:
|
||||
if self._state == RuntimeState.INITIAL:
|
||||
self._set_state(RuntimeState.NO_SESSIONS_CONNECTED)
|
||||
elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f"Bad Runtime state at start: {self._state}")
|
||||
|
||||
# Signal that we're started and ready to accept sessions
|
||||
async_objs.started.set_result(None)
|
||||
|
||||
while not async_objs.must_stop.is_set():
|
||||
if self._state == RuntimeState.NO_SESSIONS_CONNECTED: # type: ignore[comparison-overlap]
|
||||
# mypy 1.4 incorrectly thinks this if-clause is unreachable,
|
||||
# because it thinks self._state must be INITIAL | ONE_OR_MORE_SESSIONS_CONNECTED.
|
||||
|
||||
# Wait for new websocket connections (new sessions):
|
||||
_, pending_tasks = await asyncio.wait( # type: ignore[unreachable]
|
||||
(
|
||||
asyncio.create_task(async_objs.must_stop.wait()),
|
||||
asyncio.create_task(async_objs.has_connection.wait()),
|
||||
),
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
# Clean up pending tasks to avoid memory leaks
|
||||
for task in pending_tasks:
|
||||
task.cancel()
|
||||
elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED:
|
||||
async_objs.need_send_data.clear()
|
||||
|
||||
for active_session_info in self._session_mgr.list_active_sessions():
|
||||
msg_list = active_session_info.session.flush_browser_queue()
|
||||
for msg in msg_list:
|
||||
try:
|
||||
self._send_message(active_session_info, msg)
|
||||
except SessionClientDisconnectedError:
|
||||
self._session_mgr.disconnect_session(
|
||||
active_session_info.session.id
|
||||
)
|
||||
|
||||
# Yield for a tick after sending a message.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Yield for a few milliseconds between session message
|
||||
# flushing.
|
||||
await asyncio.sleep(0.01)
|
||||
else:
|
||||
# Break out of the thread loop if we encounter any other state.
|
||||
break
|
||||
|
||||
# Wait for new proto messages that need to be sent out:
|
||||
_, pending_tasks = await asyncio.wait(
|
||||
(
|
||||
asyncio.create_task(async_objs.must_stop.wait()),
|
||||
asyncio.create_task(async_objs.need_send_data.wait()),
|
||||
),
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
# We need to cancel the pending tasks (the `must_stop` one in most situations).
|
||||
# Otherwise, this would stack up one waiting task per loop
|
||||
# (e.g. per forward message). These tasks cannot be garbage collected
|
||||
# causing an increase in memory (-> memory leak).
|
||||
for task in pending_tasks:
|
||||
task.cancel()
|
||||
|
||||
# Shut down all AppSessions.
|
||||
for session_info in self._session_mgr.list_sessions():
|
||||
# NOTE: We want to fully shut down sessions when the runtime stops for
|
||||
# now, but this may change in the future if/when our notion of a session
|
||||
# is no longer so tightly coupled to a browser tab.
|
||||
self._session_mgr.close_session(session_info.session.id)
|
||||
|
||||
self._set_state(RuntimeState.STOPPED)
|
||||
async_objs.stopped.set_result(None)
|
||||
|
||||
except Exception as e:
|
||||
async_objs.stopped.set_exception(e)
|
||||
traceback.print_exc()
|
||||
_LOGGER.info(
|
||||
"""
|
||||
Please report this bug at https://github.com/streamlit/streamlit/issues.
|
||||
"""
|
||||
)
|
||||
|
||||
def _send_message(self, session_info: ActiveSessionInfo, msg: ForwardMsg) -> None:
|
||||
"""Send a message to a client.
|
||||
|
||||
If the client is likely to have already cached the message, we may
|
||||
instead send a "reference" message that contains only the hash of the
|
||||
message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_info : ActiveSessionInfo
|
||||
The ActiveSessionInfo associated with websocket
|
||||
msg : ForwardMsg
|
||||
The message to send to the client
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
msg.metadata.cacheable = is_cacheable_msg(msg)
|
||||
msg_to_send = msg
|
||||
if msg.metadata.cacheable:
|
||||
populate_hash_if_needed(msg)
|
||||
|
||||
if self._message_cache.has_message_reference(
|
||||
msg, session_info.session, session_info.script_run_count
|
||||
):
|
||||
# This session has probably cached this message. Send
|
||||
# a reference instead.
|
||||
_LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
|
||||
msg_to_send = create_reference_msg(msg)
|
||||
|
||||
# Cache the message so it can be referenced in the future.
|
||||
# If the message is already cached, this will reset its
|
||||
# age.
|
||||
_LOGGER.debug("Caching message (hash=%s)", msg.hash)
|
||||
self._message_cache.add_message(
|
||||
msg, session_info.session, session_info.script_run_count
|
||||
)
|
||||
|
||||
# If this was a `script_finished` message, we increment the
|
||||
# script_run_count for this session, and update the cache
|
||||
if msg.WhichOneof("type") == "script_finished" and (
|
||||
msg.script_finished == ForwardMsg.FINISHED_SUCCESSFULLY
|
||||
or (
|
||||
config.get_option(
|
||||
"global.includeFragmentRunsInForwardMessageCacheCount"
|
||||
)
|
||||
and msg.script_finished == ForwardMsg.FINISHED_FRAGMENT_RUN_SUCCESSFULLY
|
||||
)
|
||||
):
|
||||
_LOGGER.debug(
|
||||
"Script run finished successfully; "
|
||||
"removing expired entries from MessageCache "
|
||||
"(max_age=%s)",
|
||||
config.get_option("global.maxCachedMessageAge"),
|
||||
)
|
||||
session_info.script_run_count += 1
|
||||
self._message_cache.remove_expired_entries_for_session(
|
||||
session_info.session, session_info.script_run_count
|
||||
)
|
||||
|
||||
# Ship it off!
|
||||
session_info.client.write_forward_msg(msg_to_send)
|
||||
|
||||
def _enqueued_some_message(self) -> None:
|
||||
"""Callback called by AppSession after the AppSession has enqueued a
|
||||
message. Sets the "needs_send_data" event, which causes our core
|
||||
loop to wake up and flush client message queues.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
async_objs = self._get_async_objs()
|
||||
async_objs.eventloop.call_soon_threadsafe(async_objs.need_send_data.set)
|
||||
|
||||
def _get_async_objs(self) -> AsyncObjects:
|
||||
"""Return our AsyncObjects instance. If the Runtime hasn't been
|
||||
started, this will raise an error.
|
||||
"""
|
||||
if self._async_objs is None:
|
||||
raise RuntimeError("Runtime hasn't started yet!")
|
||||
return self._async_objs
|
||||
|
||||
def _on_session_disconnected(self) -> None:
|
||||
"""Set the runtime state to NO_SESSIONS_CONNECTED if the last active
|
||||
session was disconnected.
|
||||
"""
|
||||
if (
|
||||
self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED
|
||||
and self._session_mgr.num_active_sessions() == 0
|
||||
):
|
||||
self._get_async_objs().has_connection.clear()
|
||||
self._set_state(RuntimeState.NO_SESSIONS_CONNECTED)
|
||||
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Runtime-related utility functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from streamlit import config
|
||||
from streamlit.errors import MarkdownFormattedException, StreamlitAPIException
|
||||
from streamlit.runtime.forward_msg_cache import populate_hash_if_needed
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
|
||||
|
||||
class MessageSizeError(MarkdownFormattedException):
|
||||
"""Exception raised when a websocket message is larger than the configured limit."""
|
||||
|
||||
def __init__(self, failed_msg_str: Any):
|
||||
msg = self._get_message(failed_msg_str)
|
||||
super().__init__(msg)
|
||||
|
||||
def _get_message(self, failed_msg_str: Any) -> str:
|
||||
# This needs to have zero indentation otherwise the markdown will render incorrectly.
|
||||
return (
|
||||
f"""
|
||||
**Data of size {len(failed_msg_str) / 1e6:.1f} MB exceeds the message size limit of {get_max_message_size_bytes() / 1e6} MB.**
|
||||
|
||||
This is often caused by a large chart or dataframe. Please decrease the amount of data sent
|
||||
to the browser, or increase the limit by setting the config option `server.maxMessageSize`.
|
||||
[Click here to learn more about config options](https://docs.streamlit.io/develop/api-reference/configuration/config.toml).
|
||||
|
||||
_Note that increasing the limit may lead to long loading times and large memory consumption
|
||||
of the client's browser and the Streamlit server._
|
||||
"""
|
||||
).strip("\n")
|
||||
|
||||
|
||||
class BadDurationStringError(StreamlitAPIException):
|
||||
"""Raised when a bad duration argument string is passed."""
|
||||
|
||||
def __init__(self, duration: str):
|
||||
MarkdownFormattedException.__init__(
|
||||
self,
|
||||
"TTL string doesn't look right. It should be formatted as"
|
||||
f"`'1d2h34m'` or `2 days`, for example. Got: {duration}",
|
||||
)
|
||||
|
||||
|
||||
def is_cacheable_msg(msg: ForwardMsg) -> bool:
|
||||
"""True if the given message qualifies for caching."""
|
||||
if msg.WhichOneof("type") in {"ref_hash", "initialize"}:
|
||||
# Some message types never get cached
|
||||
return False
|
||||
return msg.ByteSize() >= int(config.get_option("global.minCachedMessageSize"))
|
||||
|
||||
|
||||
def serialize_forward_msg(msg: ForwardMsg) -> bytes:
|
||||
"""Serialize a ForwardMsg to send to a client.
|
||||
|
||||
If the message is too large, it will be converted to an exception message
|
||||
instead.
|
||||
"""
|
||||
populate_hash_if_needed(msg)
|
||||
msg_str = msg.SerializeToString()
|
||||
|
||||
if len(msg_str) > get_max_message_size_bytes():
|
||||
import streamlit.elements.exception as exception
|
||||
|
||||
# Overwrite the offending ForwardMsg.delta with an error to display.
|
||||
# This assumes that the size limit wasn't exceeded due to metadata.
|
||||
exception.marshall(msg.delta.new_element.exception, MessageSizeError(msg_str))
|
||||
msg_str = msg.SerializeToString()
|
||||
|
||||
return msg_str
|
||||
|
||||
|
||||
# This needs to be initialized lazily to avoid calling config.get_option() and
|
||||
# thus initializing config options when this file is first imported.
|
||||
_max_message_size_bytes: int | None = None
|
||||
|
||||
|
||||
def get_max_message_size_bytes() -> int:
|
||||
"""Returns the max websocket message size in bytes.
|
||||
|
||||
This will lazyload the value from the config and store it in the global symbol table.
|
||||
"""
|
||||
global _max_message_size_bytes
|
||||
|
||||
if _max_message_size_bytes is None:
|
||||
_max_message_size_bytes = config.get_option("server.maxMessageSize") * int(1e6)
|
||||
|
||||
return _max_message_size_bytes
|
||||
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScriptData:
|
||||
"""Contains parameters related to running a script."""
|
||||
|
||||
main_script_path: str
|
||||
is_hello: bool = False
|
||||
script_folder: str = field(init=False)
|
||||
name: str = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set some computed values derived from main_script_path.
|
||||
|
||||
The usage of object.__setattr__ is necessary because trying to set
|
||||
self.script_folder or self.name normally, even within the __init__ method, will
|
||||
explode since we declared this dataclass to be frozen.
|
||||
|
||||
We do this in __post_init__ so that we can use the auto-generated __init__
|
||||
method that most dataclasses use.
|
||||
"""
|
||||
main_script_path = os.path.abspath(self.main_script_path)
|
||||
script_folder = os.path.dirname(main_script_path)
|
||||
object.__setattr__(self, "script_folder", script_folder)
|
||||
|
||||
basename = os.path.basename(main_script_path)
|
||||
name = str(os.path.splitext(basename)[0])
|
||||
object.__setattr__(self, "name", name)
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from streamlit.runtime.scriptrunner.script_runner import ScriptRunner, ScriptRunnerEvent
|
||||
from streamlit.runtime.scriptrunner_utils.exceptions import (
|
||||
RerunException,
|
||||
StopException,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner_utils.script_requests import RerunData
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import (
|
||||
ScriptRunContext,
|
||||
add_script_run_ctx,
|
||||
enqueue_message,
|
||||
get_script_run_ctx,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RerunData",
|
||||
"ScriptRunContext",
|
||||
"add_script_run_ctx",
|
||||
"get_script_run_ctx",
|
||||
"enqueue_message",
|
||||
"RerunException",
|
||||
"ScriptRunner",
|
||||
"ScriptRunnerEvent",
|
||||
"StopException",
|
||||
]
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from streamlit import util
|
||||
from streamlit.delta_generator_singletons import (
|
||||
context_dg_stack,
|
||||
get_default_dg_stack_value,
|
||||
)
|
||||
from streamlit.error_util import handle_uncaught_app_exception
|
||||
from streamlit.errors import FragmentHandledException
|
||||
from streamlit.runtime.scriptrunner_utils.exceptions import (
|
||||
RerunException,
|
||||
StopException,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.scriptrunner_utils.script_requests import RerunData
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import ScriptRunContext
|
||||
|
||||
|
||||
class modified_sys_path:
|
||||
"""A context for prepending a directory to sys.path for a second.
|
||||
|
||||
Code inspired by IPython:
|
||||
Source: https://github.com/ipython/ipython/blob/master/IPython/utils/syspathcontext.py#L42
|
||||
"""
|
||||
|
||||
def __init__(self, main_script_path: str):
|
||||
self._main_script_path = main_script_path
|
||||
self._added_path = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def __enter__(self):
|
||||
if self._main_script_path not in sys.path:
|
||||
sys.path.insert(0, self._main_script_path)
|
||||
self._added_path = True
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
if self._added_path:
|
||||
try:
|
||||
sys.path.remove(self._main_script_path)
|
||||
except ValueError:
|
||||
# It's already removed.
|
||||
pass
|
||||
|
||||
# Returning False causes any exceptions to be re-raised.
|
||||
return False
|
||||
|
||||
|
||||
def exec_func_with_error_handling(
|
||||
func: Callable[[], Any], ctx: ScriptRunContext
|
||||
) -> tuple[
|
||||
Any | None,
|
||||
bool,
|
||||
RerunData | None,
|
||||
bool,
|
||||
Exception | None,
|
||||
]:
|
||||
"""Execute the passed function wrapped in a try/except block.
|
||||
|
||||
This function is called by the script runner to execute the user's script or
|
||||
fragment reruns, but also for the execution of fragment code in context of a normal
|
||||
app run. This wrapper ensures that handle_uncaught_exception messages show up in the
|
||||
correct context.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
The function to execute wrapped in the try/except block.
|
||||
ctx : ScriptRunContext
|
||||
The context in which the script is being run.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple
|
||||
A tuple containing:
|
||||
- The result of the passed function.
|
||||
- A boolean indicating whether the script ran without errors (RerunException and
|
||||
StopException don't count as errors).
|
||||
- The RerunData instance belonging to a RerunException if the script was
|
||||
interrupted by a RerunException.
|
||||
- A boolean indicating whether the script was stopped prematurely (False for
|
||||
RerunExceptions, True for all other exceptions).
|
||||
- The uncaught exception if one occurred, None otherwise
|
||||
"""
|
||||
run_without_errors = True
|
||||
|
||||
# This will be set to a RerunData instance if our execution
|
||||
# is interrupted by a RerunException.
|
||||
rerun_exception_data: RerunData | None = None
|
||||
|
||||
# If the script stops early, we don't want to remove unseen widgets,
|
||||
# so we track this to potentially skip session state cleanup later.
|
||||
premature_stop: bool = False
|
||||
|
||||
# The result of the passed function
|
||||
result: Any | None = None
|
||||
|
||||
# The uncaught exception if one occurred, None otherwise
|
||||
uncaught_exception: Exception | None = None
|
||||
|
||||
try:
|
||||
result = func()
|
||||
except RerunException as e:
|
||||
rerun_exception_data = e.rerun_data
|
||||
|
||||
# Since the script is about to rerun, we may need to reset our cursors/dg_stack
|
||||
# so that we write to the right place in the app. For full script runs, this
|
||||
# needs to happen in case the same thread reruns our script (a different thread
|
||||
# would automatically come with fresh cursors/dg_stack values). For fragments,
|
||||
# it doesn't matter either way since the fragment resets these values from its
|
||||
# snapshot before execution.
|
||||
ctx.cursors.clear()
|
||||
context_dg_stack.set(get_default_dg_stack_value())
|
||||
|
||||
# Interruption due to a rerun is usually from `st.rerun()`, which
|
||||
# we want to count as a script completion so triggers reset.
|
||||
# It is also possible for this to happen if fast reruns is off,
|
||||
# but this is very rare.
|
||||
premature_stop = False
|
||||
|
||||
except StopException:
|
||||
# This is thrown when the script executes `st.stop()`.
|
||||
# We don't have to do anything here.
|
||||
premature_stop = True
|
||||
except FragmentHandledException:
|
||||
run_without_errors = False
|
||||
premature_stop = True
|
||||
except Exception as ex:
|
||||
run_without_errors = False
|
||||
premature_stop = True
|
||||
handle_uncaught_app_exception(ex)
|
||||
uncaught_exception = ex
|
||||
|
||||
return (
|
||||
result,
|
||||
run_without_errors,
|
||||
rerun_exception_data,
|
||||
premature_stop,
|
||||
uncaught_exception,
|
||||
)
|
||||
@@ -0,0 +1,273 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import sys
|
||||
from typing import Any, Final
|
||||
|
||||
from streamlit import config
|
||||
|
||||
# When a Streamlit app is magicified, we insert a `magic_funcs` import near the top of
|
||||
# its module's AST:
|
||||
# import streamlit.runtime.scriptrunner.magic_funcs as __streamlitmagic__
|
||||
MAGIC_MODULE_NAME: Final = "__streamlitmagic__"
|
||||
|
||||
|
||||
def add_magic(code: str, script_path: str) -> Any:
|
||||
"""Modifies the code to support magic Streamlit commands.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
code : str
|
||||
The Python code.
|
||||
script_path : str
|
||||
The path to the script file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ast.Module
|
||||
The syntax tree for the code.
|
||||
|
||||
"""
|
||||
# Pass script_path so we get pretty exceptions.
|
||||
tree = ast.parse(code, script_path, "exec")
|
||||
|
||||
file_ends_in_semicolon = _does_file_end_in_semicolon(tree, code)
|
||||
|
||||
_modify_ast_subtree(
|
||||
tree, is_root=True, file_ends_in_semicolon=file_ends_in_semicolon
|
||||
)
|
||||
|
||||
return tree
|
||||
|
||||
|
||||
def _modify_ast_subtree(
|
||||
tree: Any,
|
||||
body_attr: str = "body",
|
||||
is_root: bool = False,
|
||||
file_ends_in_semicolon: bool = False,
|
||||
):
|
||||
"""Parses magic commands and modifies the given AST (sub)tree."""
|
||||
|
||||
body = getattr(tree, body_attr)
|
||||
|
||||
for i, node in enumerate(body):
|
||||
node_type = type(node)
|
||||
|
||||
# Recursively parses the content of the statements
|
||||
# `with` as well as function definitions.
|
||||
# Also covers their async counterparts
|
||||
if (
|
||||
node_type is ast.FunctionDef
|
||||
or node_type is ast.With
|
||||
or node_type is ast.AsyncFunctionDef
|
||||
or node_type is ast.AsyncWith
|
||||
):
|
||||
_modify_ast_subtree(node)
|
||||
|
||||
# Recursively parses the content of the statements
|
||||
# `for` and `while`.
|
||||
# Also covers their async counterparts
|
||||
elif (
|
||||
node_type is ast.For or node_type is ast.While or node_type is ast.AsyncFor
|
||||
):
|
||||
_modify_ast_subtree(node)
|
||||
_modify_ast_subtree(node, "orelse")
|
||||
|
||||
# Recursively parses methods in a class.
|
||||
elif node_type is ast.ClassDef:
|
||||
for inner_node in node.body:
|
||||
if type(inner_node) in {ast.FunctionDef, ast.AsyncFunctionDef}:
|
||||
_modify_ast_subtree(inner_node)
|
||||
|
||||
# Recursively parses the contents of try statements,
|
||||
# all their handlers (except and else) and the finally body
|
||||
elif node_type is ast.Try or (
|
||||
sys.version_info >= (3, 11) and node_type is ast.TryStar
|
||||
):
|
||||
_modify_ast_subtree(node)
|
||||
_modify_ast_subtree(node, body_attr="finalbody")
|
||||
_modify_ast_subtree(node, body_attr="orelse")
|
||||
for handler_node in node.handlers:
|
||||
_modify_ast_subtree(handler_node)
|
||||
|
||||
# Recursively parses if blocks, as well as their else/elif blocks
|
||||
# (else/elif are both mapped to orelse)
|
||||
# it intentionally does not parse the test expression.
|
||||
elif node_type is ast.If:
|
||||
_modify_ast_subtree(node)
|
||||
_modify_ast_subtree(node, "orelse")
|
||||
|
||||
elif sys.version_info >= (3, 10) and node_type is ast.Match:
|
||||
for case_node in node.cases:
|
||||
_modify_ast_subtree(case_node)
|
||||
|
||||
# Convert standalone expression nodes to st.write
|
||||
elif node_type is ast.Expr:
|
||||
value = _get_st_write_from_expr(
|
||||
node,
|
||||
i,
|
||||
parent_type=type(tree),
|
||||
is_root=is_root,
|
||||
is_last_expr=(i == len(body) - 1),
|
||||
file_ends_in_semicolon=file_ends_in_semicolon,
|
||||
)
|
||||
if value is not None:
|
||||
node.value = value
|
||||
|
||||
if is_root:
|
||||
# Import Streamlit so we can use it in the new_value above.
|
||||
_insert_import_statement(tree)
|
||||
|
||||
ast.fix_missing_locations(tree)
|
||||
|
||||
|
||||
def _insert_import_statement(tree: Any) -> None:
|
||||
"""Insert Streamlit import statement at the top(ish) of the tree."""
|
||||
|
||||
st_import = _build_st_import_statement()
|
||||
|
||||
# If the 0th node is already an import statement, put the Streamlit
|
||||
# import below that, so we don't break "from __future__ import".
|
||||
if tree.body and type(tree.body[0]) in {ast.ImportFrom, ast.Import}:
|
||||
tree.body.insert(1, st_import)
|
||||
|
||||
# If the 0th node is a docstring and the 1st is an import statement,
|
||||
# put the Streamlit import below those, so we don't break "from
|
||||
# __future__ import".
|
||||
elif (
|
||||
len(tree.body) > 1
|
||||
and (
|
||||
type(tree.body[0]) is ast.Expr
|
||||
and _is_string_constant_node(tree.body[0].value)
|
||||
)
|
||||
and type(tree.body[1]) in {ast.ImportFrom, ast.Import}
|
||||
):
|
||||
tree.body.insert(2, st_import)
|
||||
|
||||
else:
|
||||
tree.body.insert(0, st_import)
|
||||
|
||||
|
||||
def _build_st_import_statement():
|
||||
"""Build AST node for `import magic_funcs as __streamlitmagic__`."""
|
||||
return ast.Import(
|
||||
names=[
|
||||
ast.alias(
|
||||
name="streamlit.runtime.scriptrunner.magic_funcs",
|
||||
asname=MAGIC_MODULE_NAME,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _build_st_write_call(nodes):
|
||||
"""Build AST node for `__streamlitmagic__.transparent_write(*nodes)`."""
|
||||
return ast.Call(
|
||||
func=ast.Attribute(
|
||||
attr="transparent_write",
|
||||
value=ast.Name(id=MAGIC_MODULE_NAME, ctx=ast.Load()),
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
args=nodes,
|
||||
keywords=[],
|
||||
)
|
||||
|
||||
|
||||
def _get_st_write_from_expr(
|
||||
node, i, parent_type, is_root, is_last_expr, file_ends_in_semicolon
|
||||
):
|
||||
# Don't wrap function calls
|
||||
# (Unless the function call happened at the end of the root node, AND
|
||||
# magic.displayLastExprIfNoSemicolon is True. This allows us to support notebook-like
|
||||
# behavior, where we display the last function in a cell)
|
||||
if type(node.value) is ast.Call and not _is_displayable_last_expr(
|
||||
is_root, is_last_expr, file_ends_in_semicolon
|
||||
):
|
||||
return None
|
||||
|
||||
# Don't wrap DocString nodes
|
||||
# (Unless magic.displayRootDocString, in which case we do wrap the root-level
|
||||
# docstring with st.write. This allows us to support notebook-like behavior
|
||||
# where you can have a cell with a markdown string)
|
||||
if _is_docstring_node(
|
||||
node.value, i, parent_type
|
||||
) and not _should_display_docstring_like_node_anyway(is_root):
|
||||
return None
|
||||
|
||||
# Don't wrap yield nodes
|
||||
if type(node.value) is ast.Yield or type(node.value) is ast.YieldFrom:
|
||||
return None
|
||||
|
||||
# Don't wrap await nodes
|
||||
if type(node.value) is ast.Await:
|
||||
return None
|
||||
|
||||
# If tuple, call st.write(*the_tuple). This allows us to add a comma at the end of a
|
||||
# statement to turn it into an expression that should be st-written. Ex:
|
||||
# "np.random.randn(1000, 2),"
|
||||
args = node.value.elts if type(node.value) is ast.Tuple else [node.value]
|
||||
return _build_st_write_call(args)
|
||||
|
||||
|
||||
def _is_string_constant_node(node) -> bool:
|
||||
return isinstance(node, ast.Constant) and isinstance(node.value, str)
|
||||
|
||||
|
||||
def _is_docstring_node(node, node_index, parent_type) -> bool:
|
||||
return (
|
||||
node_index == 0
|
||||
and _is_string_constant_node(node)
|
||||
and parent_type in {ast.FunctionDef, ast.AsyncFunctionDef, ast.Module}
|
||||
)
|
||||
|
||||
|
||||
def _does_file_end_in_semicolon(tree, code: str) -> bool:
|
||||
file_ends_in_semicolon = False
|
||||
|
||||
# Avoid spending time with this operation if magic.displayLastExprIfNoSemicolon is
|
||||
# not set.
|
||||
if config.get_option("magic.displayLastExprIfNoSemicolon"):
|
||||
if len(tree.body) == 0:
|
||||
return False
|
||||
|
||||
last_line_num = getattr(tree.body[-1], "end_lineno", None)
|
||||
|
||||
if last_line_num is not None:
|
||||
last_line_str: str = code.split("\n")[last_line_num - 1]
|
||||
file_ends_in_semicolon = last_line_str.strip(" ").endswith(";")
|
||||
|
||||
return file_ends_in_semicolon
|
||||
|
||||
|
||||
def _is_displayable_last_expr(
|
||||
is_root: bool, is_last_expr: bool, file_ends_in_semicolon: bool
|
||||
) -> bool:
|
||||
return (
|
||||
# This is a "displayable last expression" if...
|
||||
# ...it's actually the last expression...
|
||||
is_last_expr
|
||||
# ...in the root scope...
|
||||
and is_root
|
||||
# ...it does not end in a semicolon...
|
||||
and not file_ends_in_semicolon
|
||||
# ...and this config option is telling us to show it
|
||||
and config.get_option("magic.displayLastExprIfNoSemicolon")
|
||||
)
|
||||
|
||||
|
||||
def _should_display_docstring_like_node_anyway(is_root: bool) -> bool:
|
||||
return config.get_option("magic.displayRootDocString") and is_root
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
|
||||
|
||||
@gather_metrics("magic")
|
||||
def transparent_write(*args: Any) -> Any:
|
||||
"""The function that gets magic-ified into Streamlit apps.
|
||||
This is just st.write, but returns the arguments you passed to it.
|
||||
"""
|
||||
import streamlit as st
|
||||
|
||||
st.write(*args)
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return args
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os.path
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from streamlit import config
|
||||
from streamlit.runtime.scriptrunner import magic
|
||||
from streamlit.source_util import open_python_file
|
||||
|
||||
|
||||
class ScriptCache:
|
||||
"""Thread-safe cache of Python script bytecode."""
|
||||
|
||||
def __init__(self):
|
||||
# Mapping of script_path: bytecode
|
||||
self._cache: dict[str, Any] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all entries from the cache.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
with self._lock:
|
||||
self._cache.clear()
|
||||
|
||||
def get_bytecode(self, script_path: str) -> Any:
|
||||
"""Return the bytecode for the Python script at the given path.
|
||||
|
||||
If the bytecode is not already in the cache, the script will be
|
||||
compiled first.
|
||||
|
||||
Raises
|
||||
------
|
||||
Any Exception raised while reading or compiling the script.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
|
||||
script_path = os.path.abspath(script_path)
|
||||
|
||||
with self._lock:
|
||||
bytecode = self._cache.get(script_path, None)
|
||||
if bytecode is not None:
|
||||
# Fast path: the code is already cached.
|
||||
return bytecode
|
||||
|
||||
# Populate the cache
|
||||
with open_python_file(script_path) as f:
|
||||
filebody = f.read()
|
||||
|
||||
if config.get_option("runner.magicEnabled"):
|
||||
filebody = magic.add_magic(filebody, script_path)
|
||||
|
||||
bytecode = compile( # type: ignore
|
||||
filebody,
|
||||
# Pass in the file path so it can show up in exceptions.
|
||||
script_path,
|
||||
# We're compiling entire blocks of Python, so we need "exec"
|
||||
# mode (as opposed to "eval" or "single").
|
||||
mode="exec",
|
||||
# Don't inherit any flags or "future" statements.
|
||||
flags=0,
|
||||
dont_inherit=1,
|
||||
# Use the default optimization options.
|
||||
optimize=-1,
|
||||
)
|
||||
|
||||
self._cache[script_path] = bytecode
|
||||
return bytecode
|
||||
@@ -0,0 +1,756 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from timeit import default_timer as timer
|
||||
from typing import TYPE_CHECKING, Callable, Final, Literal, cast
|
||||
|
||||
from blinker import Signal
|
||||
|
||||
from streamlit import config, runtime, util
|
||||
from streamlit.errors import FragmentStorageKeyError
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.ClientState_pb2 import ClientState
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.metrics_util import (
|
||||
create_page_profile_message,
|
||||
to_microseconds,
|
||||
)
|
||||
from streamlit.runtime.pages_manager import PagesManager
|
||||
from streamlit.runtime.scriptrunner.exec_code import (
|
||||
exec_func_with_error_handling,
|
||||
modified_sys_path,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
|
||||
from streamlit.runtime.scriptrunner_utils.exceptions import (
|
||||
RerunException,
|
||||
StopException,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner_utils.script_requests import (
|
||||
RerunData,
|
||||
ScriptRequests,
|
||||
ScriptRequestType,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import (
|
||||
ScriptRunContext,
|
||||
add_script_run_ctx,
|
||||
get_script_run_ctx,
|
||||
)
|
||||
from streamlit.runtime.state import (
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
|
||||
SafeSessionState,
|
||||
SessionState,
|
||||
)
|
||||
from streamlit.source_util import page_sort_key
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.fragment import FragmentStorage
|
||||
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
class ScriptRunnerEvent(Enum):
|
||||
# "Control" events. These are emitted when the ScriptRunner's state changes.
|
||||
|
||||
# The script started running.
|
||||
SCRIPT_STARTED = "SCRIPT_STARTED"
|
||||
|
||||
# The script run stopped because of a compile error.
|
||||
SCRIPT_STOPPED_WITH_COMPILE_ERROR = "SCRIPT_STOPPED_WITH_COMPILE_ERROR"
|
||||
|
||||
# The script run stopped because it ran to completion, or was
|
||||
# interrupted by the user.
|
||||
SCRIPT_STOPPED_WITH_SUCCESS = "SCRIPT_STOPPED_WITH_SUCCESS"
|
||||
|
||||
# The script run stopped in order to start a script run with newer widget state.
|
||||
SCRIPT_STOPPED_FOR_RERUN = "SCRIPT_STOPPED_FOR_RERUN"
|
||||
|
||||
# The script run corresponding to a fragment ran to completion, or was interrupted
|
||||
# by the user.
|
||||
FRAGMENT_STOPPED_WITH_SUCCESS = "FRAGMENT_STOPPED_WITH_SUCCESS"
|
||||
|
||||
# The ScriptRunner is done processing the ScriptEventQueue and
|
||||
# is shut down.
|
||||
SHUTDOWN = "SHUTDOWN"
|
||||
|
||||
# "Data" events. These are emitted when the ScriptRunner's script has
|
||||
# data to send to the frontend.
|
||||
|
||||
# The script has a ForwardMsg to send to the frontend.
|
||||
ENQUEUE_FORWARD_MSG = "ENQUEUE_FORWARD_MSG"
|
||||
|
||||
|
||||
"""
|
||||
Note [Threading]
|
||||
There are two kinds of threads in Streamlit, the main thread and script threads.
|
||||
The main thread is started by invoking the Streamlit CLI, and bootstraps the
|
||||
framework and runs the Tornado webserver.
|
||||
A script thread is created by a ScriptRunner when it starts. The script thread
|
||||
is where the ScriptRunner executes, including running the user script itself,
|
||||
processing messages to/from the frontend, and all the Streamlit library function
|
||||
calls in the user script.
|
||||
It is possible for the user script to spawn its own threads, which could call
|
||||
Streamlit functions. We restrict the ScriptRunner's execution control to the
|
||||
script thread. Calling Streamlit functions from other threads is unlikely to
|
||||
work correctly due to lack of ScriptRunContext, so we may add a guard against
|
||||
it in the future.
|
||||
"""
|
||||
|
||||
|
||||
# For projects that have a pages folder, we assume that this is a script that
|
||||
# is designed to leverage our original v1 version of multi-page apps. This
|
||||
# function will be called to run the script in lieu of the main script. This
|
||||
# function simulates the v1 setup using the modern v2 commands (st.navigation)
|
||||
def _mpa_v1(main_script_path: str):
|
||||
from pathlib import Path
|
||||
|
||||
from streamlit.commands.navigation import PageType, _navigation
|
||||
from streamlit.navigation.page import StreamlitPage
|
||||
|
||||
# Select the folder that should be used for the pages:
|
||||
MAIN_SCRIPT_PATH = Path(main_script_path).resolve()
|
||||
PAGES_FOLDER = MAIN_SCRIPT_PATH.parent / "pages"
|
||||
|
||||
# Read out the my_pages folder and create a page for every script:
|
||||
pages = PAGES_FOLDER.glob("*.py")
|
||||
pages = sorted(
|
||||
[page for page in pages if page.name.endswith(".py")], key=page_sort_key
|
||||
)
|
||||
|
||||
# Use this script as the main page and
|
||||
main_page = StreamlitPage(MAIN_SCRIPT_PATH, default=True)
|
||||
all_pages = [main_page] + [
|
||||
StreamlitPage(PAGES_FOLDER / page.name) for page in pages
|
||||
]
|
||||
# Initialize the navigation with all the pages:
|
||||
position: Literal["sidebar", "hidden"] = (
|
||||
"hidden"
|
||||
if config.get_option("client.showSidebarNavigation") is False
|
||||
else "sidebar"
|
||||
)
|
||||
page = _navigation(
|
||||
cast("list[PageType]", all_pages),
|
||||
position=position,
|
||||
expanded=False,
|
||||
)
|
||||
|
||||
if page._page != main_page._page:
|
||||
# Only run the page if it is not pointing to this script:
|
||||
page.run()
|
||||
# Finish the script execution here to only run the selected page
|
||||
raise StopException()
|
||||
|
||||
|
||||
class ScriptRunner:
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
main_script_path: str,
|
||||
session_state: SessionState,
|
||||
uploaded_file_mgr: UploadedFileManager,
|
||||
script_cache: ScriptCache,
|
||||
initial_rerun_data: RerunData,
|
||||
user_info: dict[str, str | bool | None],
|
||||
fragment_storage: FragmentStorage,
|
||||
pages_manager: PagesManager,
|
||||
):
|
||||
"""Initialize the ScriptRunner.
|
||||
|
||||
(The ScriptRunner won't start executing until start() is called.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The AppSession's id.
|
||||
|
||||
main_script_path
|
||||
Path to our main app script.
|
||||
|
||||
session_state
|
||||
The AppSession's SessionState instance.
|
||||
|
||||
uploaded_file_mgr
|
||||
The File manager to store the data uploaded by the file_uploader widget.
|
||||
|
||||
script_cache
|
||||
A ScriptCache instance.
|
||||
|
||||
initial_rerun_data
|
||||
RerunData to initialize this ScriptRunner with.
|
||||
|
||||
user_info
|
||||
A dict that contains information about the current user. For now,
|
||||
it only contains the user's email address.
|
||||
|
||||
{
|
||||
"email": "example@example.com"
|
||||
}
|
||||
|
||||
Information about the current user is optionally provided when a
|
||||
websocket connection is initialized via the "X-Streamlit-User" header.
|
||||
|
||||
fragment_storage
|
||||
The AppSession's FragmentStorage instance.
|
||||
"""
|
||||
self._session_id = session_id
|
||||
self._main_script_path = main_script_path
|
||||
self._session_state = SafeSessionState(
|
||||
session_state, yield_callback=self._maybe_handle_execution_control_request
|
||||
)
|
||||
self._uploaded_file_mgr = uploaded_file_mgr
|
||||
self._script_cache = script_cache
|
||||
self._user_info = user_info
|
||||
self._fragment_storage = fragment_storage
|
||||
|
||||
self._pages_manager = pages_manager
|
||||
self._requests = ScriptRequests()
|
||||
self._requests.request_rerun(initial_rerun_data)
|
||||
|
||||
self.on_event = Signal(
|
||||
doc="""Emitted when a ScriptRunnerEvent occurs.
|
||||
|
||||
This signal is generally emitted on the ScriptRunner's script
|
||||
thread (which is *not* the same thread that the ScriptRunner was
|
||||
created on).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sender: ScriptRunner
|
||||
The sender of the event (this ScriptRunner).
|
||||
|
||||
event : ScriptRunnerEvent
|
||||
|
||||
forward_msg : ForwardMsg | None
|
||||
The ForwardMsg to send to the frontend. Set only for the
|
||||
ENQUEUE_FORWARD_MSG event.
|
||||
|
||||
exception : BaseException | None
|
||||
Our compile error. Set only for the
|
||||
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
|
||||
|
||||
widget_states : streamlit.proto.WidgetStates_pb2.WidgetStates | None
|
||||
The ScriptRunner's final WidgetStates. Set only for the
|
||||
SHUTDOWN event.
|
||||
"""
|
||||
)
|
||||
|
||||
# Set to true while we're executing. Used by
|
||||
# _maybe_handle_execution_control_request.
|
||||
self._execing = False
|
||||
|
||||
# This is initialized in start()
|
||||
self._script_thread: threading.Thread | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def request_stop(self) -> None:
|
||||
"""Request that the ScriptRunner stop running its script and
|
||||
shut down. The ScriptRunner will handle this request when it reaches
|
||||
an interrupt point.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
self._requests.request_stop()
|
||||
|
||||
def request_rerun(self, rerun_data: RerunData) -> bool:
|
||||
"""Request that the ScriptRunner interrupt its currently-running
|
||||
script and restart it.
|
||||
|
||||
If the ScriptRunner has been stopped, this request can't be honored:
|
||||
return False.
|
||||
|
||||
Otherwise, record the request and return True. The ScriptRunner will
|
||||
handle the rerun request as soon as it reaches an interrupt point.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
return self._requests.request_rerun(rerun_data)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start a new thread to process the ScriptEventQueue.
|
||||
|
||||
This must be called only once.
|
||||
|
||||
"""
|
||||
if self._script_thread is not None:
|
||||
raise Exception("ScriptRunner was already started")
|
||||
|
||||
self._script_thread = threading.Thread(
|
||||
target=self._run_script_thread,
|
||||
name="ScriptRunner.scriptThread",
|
||||
)
|
||||
self._script_thread.start()
|
||||
|
||||
def _get_script_run_ctx(self) -> ScriptRunContext:
|
||||
"""Get the ScriptRunContext for the current thread.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ScriptRunContext
|
||||
The ScriptRunContext for the current thread.
|
||||
|
||||
Raises
|
||||
------
|
||||
AssertionError
|
||||
If called outside of a ScriptRunner thread.
|
||||
RuntimeError
|
||||
If there is no ScriptRunContext for the current thread.
|
||||
|
||||
"""
|
||||
assert self._is_in_script_thread()
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
# This should never be possible on the script_runner thread.
|
||||
raise RuntimeError(
|
||||
"ScriptRunner thread has a null ScriptRunContext. "
|
||||
"Something has gone very wrong!"
|
||||
)
|
||||
return ctx
|
||||
|
||||
def _run_script_thread(self) -> None:
|
||||
"""The entry point for the script thread.
|
||||
|
||||
Processes the ScriptRequestQueue, which will at least contain the RERUN
|
||||
request that will trigger the first script-run.
|
||||
|
||||
When the ScriptRequestQueue is empty, or when a SHUTDOWN request is
|
||||
dequeued, this function will exit and its thread will terminate.
|
||||
"""
|
||||
assert self._is_in_script_thread()
|
||||
|
||||
_LOGGER.debug("Beginning script thread")
|
||||
|
||||
# Create and attach the thread's ScriptRunContext
|
||||
ctx = ScriptRunContext(
|
||||
session_id=self._session_id,
|
||||
_enqueue=self._enqueue_forward_msg,
|
||||
script_requests=self._requests,
|
||||
query_string="",
|
||||
session_state=self._session_state,
|
||||
uploaded_file_mgr=self._uploaded_file_mgr,
|
||||
main_script_path=self._main_script_path,
|
||||
user_info=self._user_info,
|
||||
gather_usage_stats=bool(config.get_option("browser.gatherUsageStats")),
|
||||
fragment_storage=self._fragment_storage,
|
||||
pages_manager=self._pages_manager,
|
||||
context_info=None,
|
||||
)
|
||||
add_script_run_ctx(threading.current_thread(), ctx)
|
||||
|
||||
request = self._requests.on_scriptrunner_ready()
|
||||
while request.type == ScriptRequestType.RERUN:
|
||||
# When the script thread starts, we'll have a pending rerun
|
||||
# request that we'll handle immediately. When the script finishes,
|
||||
# it's possible that another request has come in that we need to
|
||||
# handle, which is why we call _run_script in a loop.
|
||||
self._run_script(request.rerun_data)
|
||||
request = self._requests.on_scriptrunner_ready()
|
||||
|
||||
assert request.type == ScriptRequestType.STOP
|
||||
|
||||
# Send a SHUTDOWN event before exiting, so some state can be saved
|
||||
# for use in a future script run when not triggered by the client.
|
||||
client_state = ClientState()
|
||||
client_state.query_string = ctx.query_string
|
||||
client_state.page_script_hash = ctx.page_script_hash
|
||||
self.on_event.send(
|
||||
self, event=ScriptRunnerEvent.SHUTDOWN, client_state=client_state
|
||||
)
|
||||
|
||||
def _is_in_script_thread(self) -> bool:
|
||||
"""True if the calling function is running in the script thread."""
|
||||
return self._script_thread == threading.current_thread()
|
||||
|
||||
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
|
||||
"""Enqueue a ForwardMsg to our browser queue.
|
||||
This private function is called by ScriptRunContext only.
|
||||
|
||||
It may be called from the script thread OR the main thread.
|
||||
"""
|
||||
# Whenever we enqueue a ForwardMsg, we also handle any pending
|
||||
# execution control request. This means that a script can be
|
||||
# cleanly interrupted and stopped inside most `st.foo` calls.
|
||||
self._maybe_handle_execution_control_request()
|
||||
|
||||
# Pass the message to our associated AppSession.
|
||||
self.on_event.send(
|
||||
self, event=ScriptRunnerEvent.ENQUEUE_FORWARD_MSG, forward_msg=msg
|
||||
)
|
||||
|
||||
def _maybe_handle_execution_control_request(self) -> None:
|
||||
"""Check our current ScriptRequestState to see if we have a
|
||||
pending STOP or RERUN request.
|
||||
|
||||
This function is called every time the app script enqueues a
|
||||
ForwardMsg, which means that most `st.foo` commands - which generally
|
||||
involve sending a ForwardMsg to the frontend - act as implicit
|
||||
yield points in the script's execution.
|
||||
"""
|
||||
if not self._is_in_script_thread():
|
||||
# We can only handle execution_control_request if we're on the
|
||||
# script execution thread. However, it's possible for deltas to
|
||||
# be enqueued (and, therefore, for this function to be called)
|
||||
# in separate threads, so we check for that here.
|
||||
return
|
||||
|
||||
if not self._execing:
|
||||
# If the _execing flag is not set, we're not actually inside
|
||||
# an exec() call. This happens when our script exec() completes,
|
||||
# we change our state to STOPPED, and a statechange-listener
|
||||
# enqueues a new ForwardEvent
|
||||
return
|
||||
|
||||
request = self._requests.on_scriptrunner_yield()
|
||||
if request is None:
|
||||
# No RERUN or STOP request.
|
||||
return
|
||||
|
||||
if request.type == ScriptRequestType.RERUN:
|
||||
raise RerunException(request.rerun_data)
|
||||
|
||||
assert request.type == ScriptRequestType.STOP
|
||||
raise StopException()
|
||||
|
||||
@contextmanager
|
||||
def _set_execing_flag(self):
|
||||
"""A context for setting the ScriptRunner._execing flag.
|
||||
|
||||
Used by _maybe_handle_execution_control_request to ensure that
|
||||
we only handle requests while we're inside an exec() call
|
||||
"""
|
||||
if self._execing:
|
||||
raise RuntimeError("Nested set_execing_flag call")
|
||||
self._execing = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._execing = False
|
||||
|
||||
def _run_script(self, rerun_data: RerunData) -> None:
|
||||
"""Run our script.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rerun_data: RerunData
|
||||
The RerunData to use.
|
||||
|
||||
"""
|
||||
|
||||
assert self._is_in_script_thread()
|
||||
|
||||
# An explicit loop instead of recursion to avoid stack overflows
|
||||
while True:
|
||||
_LOGGER.debug("Running script %s", rerun_data)
|
||||
start_time: float = timer()
|
||||
prep_time: float = 0 # This will be overwritten once preparations are done.
|
||||
|
||||
if not rerun_data.fragment_id_queue:
|
||||
# Don't clear session refs for media files if we're running a fragment.
|
||||
# Otherwise, we're likely to remove files that still have corresponding
|
||||
# download buttons/links to them present in the app, which will result
|
||||
# in a 404 should the user click on them.
|
||||
runtime.get_instance().media_file_mgr.clear_session_refs()
|
||||
|
||||
self._pages_manager.set_script_intent(
|
||||
rerun_data.page_script_hash, rerun_data.page_name
|
||||
)
|
||||
active_script = self._pages_manager.get_initial_active_script(
|
||||
rerun_data.page_script_hash, rerun_data.page_name
|
||||
)
|
||||
main_page_info = self._pages_manager.get_main_page()
|
||||
|
||||
page_script_hash = (
|
||||
active_script["page_script_hash"]
|
||||
if active_script is not None
|
||||
else main_page_info["page_script_hash"]
|
||||
)
|
||||
|
||||
ctx = self._get_script_run_ctx()
|
||||
# Clear widget state on page change. This normally happens implicitly
|
||||
# in the script run cleanup steps, but doing it explicitly ensures
|
||||
# it happens even if a script run was interrupted.
|
||||
previous_page_script_hash = ctx.page_script_hash
|
||||
if previous_page_script_hash != page_script_hash:
|
||||
# Page changed, enforce reset widget state where possible.
|
||||
# This enforcement matters when a new script thread is started
|
||||
# before the previous script run is completed (from user
|
||||
# interaction). Use the widget ids from the rerun data to
|
||||
# maintain some widget state, as the rerun data should
|
||||
# contain the latest widget ids from the frontend.
|
||||
widget_ids: set[str] = set()
|
||||
|
||||
if (
|
||||
rerun_data.widget_states is not None
|
||||
and rerun_data.widget_states.widgets is not None
|
||||
):
|
||||
widget_ids = {w.id for w in rerun_data.widget_states.widgets}
|
||||
self._session_state.on_script_finished(widget_ids)
|
||||
|
||||
fragment_ids_this_run = list(rerun_data.fragment_id_queue)
|
||||
|
||||
ctx.reset(
|
||||
query_string=rerun_data.query_string,
|
||||
page_script_hash=page_script_hash,
|
||||
fragment_ids_this_run=fragment_ids_this_run,
|
||||
context_info=rerun_data.context_info,
|
||||
)
|
||||
|
||||
self.on_event.send(
|
||||
self,
|
||||
event=ScriptRunnerEvent.SCRIPT_STARTED,
|
||||
page_script_hash=page_script_hash,
|
||||
fragment_ids_this_run=fragment_ids_this_run,
|
||||
pages=self._pages_manager.get_pages(),
|
||||
)
|
||||
|
||||
# Compile the script. Any errors thrown here will be surfaced
|
||||
# to the user via a modal dialog in the frontend, and won't result
|
||||
# in their previous script elements disappearing.
|
||||
try:
|
||||
if active_script is not None:
|
||||
script_path = active_script["script_path"]
|
||||
else:
|
||||
# page must not be found
|
||||
script_path = main_page_info["script_path"]
|
||||
|
||||
# At this point, we know that either
|
||||
# * the script corresponding to the hash requested no longer
|
||||
# exists, or
|
||||
# * we were not able to find a script with the requested page
|
||||
# name.
|
||||
# In both of these cases, we want to send a page_not_found
|
||||
# message to the frontend.
|
||||
msg = ForwardMsg()
|
||||
msg.page_not_found.page_name = rerun_data.page_name
|
||||
ctx.enqueue(msg)
|
||||
|
||||
code = self._script_cache.get_bytecode(script_path)
|
||||
|
||||
except Exception as ex:
|
||||
# We got a compile error. Send an error event and bail immediately.
|
||||
_LOGGER.exception("Script compilation error", exc_info=ex)
|
||||
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = False
|
||||
self.on_event.send(
|
||||
self,
|
||||
event=ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR,
|
||||
exception=ex,
|
||||
)
|
||||
return
|
||||
|
||||
# If we get here, we've successfully compiled our script. The next step
|
||||
# is to run it. Errors thrown during execution will be shown to the
|
||||
# user as ExceptionElements.
|
||||
|
||||
# Create fake module. This gives us a name global namespace to
|
||||
# execute the code in.
|
||||
module = self._new_module("__main__")
|
||||
|
||||
# Install the fake module as the __main__ module. This allows
|
||||
# the pickle module to work inside the user's code, since it now
|
||||
# can know the module where the pickled objects stem from.
|
||||
# IMPORTANT: This means we can't use "if __name__ == '__main__'" in
|
||||
# our code, as it will point to the wrong module!!!
|
||||
sys.modules["__main__"] = module
|
||||
|
||||
# Add special variables to the module's globals dict.
|
||||
# Note: The following is a requirement for the CodeHasher to
|
||||
# work correctly. The CodeHasher is scoped to
|
||||
# files contained in the directory of __main__.__file__, which we
|
||||
# assume is the main script directory.
|
||||
module.__dict__["__file__"] = script_path
|
||||
|
||||
def code_to_exec(code=code, module=module, ctx=ctx, rerun_data=rerun_data):
|
||||
with (
|
||||
modified_sys_path(self._main_script_path),
|
||||
self._set_execing_flag(),
|
||||
):
|
||||
# Run callbacks for widgets whose values have changed.
|
||||
if rerun_data.widget_states is not None:
|
||||
self._session_state.on_script_will_rerun(
|
||||
rerun_data.widget_states
|
||||
)
|
||||
|
||||
ctx.on_script_start()
|
||||
|
||||
if rerun_data.fragment_id_queue:
|
||||
for fragment_id in rerun_data.fragment_id_queue:
|
||||
try:
|
||||
wrapped_fragment = self._fragment_storage.get(
|
||||
fragment_id
|
||||
)
|
||||
wrapped_fragment()
|
||||
|
||||
except FragmentStorageKeyError:
|
||||
# This can happen if the fragment_id is removed from the
|
||||
# storage before the script runner gets to it. In this
|
||||
# case, the fragment is simply skipped.
|
||||
# Also, only log an error if the fragment is not an
|
||||
# auto_rerun to avoid noise. If it is an auto_rerun, we
|
||||
# might have a race condition where the fragment_id is
|
||||
# removed but the webapp sends a rerun request before
|
||||
# the removal information has reached the web app
|
||||
# (see https://github.com/streamlit/streamlit/issues/9080).
|
||||
if not rerun_data.is_auto_rerun:
|
||||
_LOGGER.warning(
|
||||
f"Couldn't find fragment with id {fragment_id}."
|
||||
" This can happen if the fragment does not"
|
||||
" exist anymore when this request is processed,"
|
||||
" for example because a full app rerun happened"
|
||||
" that did not register the fragment."
|
||||
" Usually this doesn't happen or no action is"
|
||||
" required, so its mainly for debugging."
|
||||
)
|
||||
except (RerunException, StopException) as e:
|
||||
# The wrapped_fragment function is executed
|
||||
# inside of a exec_func_with_error_handling call, so
|
||||
# there is a correct handler for these exceptions.
|
||||
raise e
|
||||
except Exception:
|
||||
# Ignore exceptions raised by fragments here as we don't
|
||||
# want to stop the execution of other fragments. The
|
||||
# error itself is already rendered within the wrapped
|
||||
# fragment.
|
||||
pass
|
||||
|
||||
else:
|
||||
if PagesManager.uses_pages_directory:
|
||||
_mpa_v1(self._main_script_path)
|
||||
exec(code, module.__dict__)
|
||||
self._fragment_storage.clear(
|
||||
new_fragment_ids=ctx.new_fragment_ids
|
||||
)
|
||||
|
||||
self._session_state.maybe_check_serializable()
|
||||
# check for control requests, e.g. rerun requests have arrived
|
||||
self._maybe_handle_execution_control_request()
|
||||
|
||||
prep_time = timer() - start_time
|
||||
(
|
||||
_,
|
||||
run_without_errors,
|
||||
rerun_exception_data,
|
||||
premature_stop,
|
||||
uncaught_exception,
|
||||
) = exec_func_with_error_handling(code_to_exec, ctx)
|
||||
# setting the session state here triggers a yield-callback call
|
||||
# which reads self._requests and checks for rerun data
|
||||
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = run_without_errors
|
||||
|
||||
if rerun_exception_data:
|
||||
# The handling for when a full script run or a fragment is stopped early
|
||||
# is the same, so we only have one ScriptRunnerEvent for this scenario.
|
||||
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_FOR_RERUN
|
||||
elif rerun_data.fragment_id_queue:
|
||||
finished_event = ScriptRunnerEvent.FRAGMENT_STOPPED_WITH_SUCCESS
|
||||
else:
|
||||
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
|
||||
if ctx.gather_usage_stats:
|
||||
try:
|
||||
# Create and send page profile information
|
||||
ctx.enqueue(
|
||||
create_page_profile_message(
|
||||
commands=ctx.tracked_commands,
|
||||
exec_time=to_microseconds(timer() - start_time),
|
||||
prep_time=to_microseconds(prep_time),
|
||||
uncaught_exception=(
|
||||
type(uncaught_exception).__name__
|
||||
if uncaught_exception
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as ex:
|
||||
# Always capture all exceptions since we want to make sure that
|
||||
# the telemetry never causes any issues.
|
||||
_LOGGER.debug("Failed to create page profile", exc_info=ex)
|
||||
self._on_script_finished(ctx, finished_event, premature_stop)
|
||||
|
||||
# # Use _log_if_error() to make sure we never ever ever stop running the
|
||||
# # script without meaning to.
|
||||
_log_if_error(_clean_problem_modules)
|
||||
|
||||
if rerun_exception_data is not None:
|
||||
rerun_data = rerun_exception_data
|
||||
else:
|
||||
break
|
||||
|
||||
def _on_script_finished(
|
||||
self, ctx: ScriptRunContext, event: ScriptRunnerEvent, premature_stop: bool
|
||||
) -> None:
|
||||
"""Called when our script finishes executing, even if it finished
|
||||
early with an exception. We perform post-run cleanup here.
|
||||
"""
|
||||
# Tell session_state to update itself in response
|
||||
if not premature_stop:
|
||||
self._session_state.on_script_finished(ctx.widget_ids_this_run)
|
||||
|
||||
# Signal that the script has finished. (We use SCRIPT_STOPPED_WITH_SUCCESS
|
||||
# even if we were stopped with an exception.)
|
||||
self.on_event.send(self, event=event)
|
||||
|
||||
# Remove orphaned files now that the script has run and files in use
|
||||
# are marked as active.
|
||||
runtime.get_instance().media_file_mgr.remove_orphaned_files()
|
||||
|
||||
# Force garbage collection to run, to help avoid memory use building up
|
||||
# This is usually not an issue, but sometimes GC takes time to kick in and
|
||||
# causes apps to go over resource limits, and forcing it to run between
|
||||
# script runs is low cost, since we aren't doing much work anyway.
|
||||
if config.get_option("runner.postScriptGC"):
|
||||
gc.collect(2)
|
||||
|
||||
def _new_module(self, name: str) -> types.ModuleType:
|
||||
"""Create a new module with the given name."""
|
||||
return types.ModuleType(name)
|
||||
|
||||
|
||||
def _clean_problem_modules() -> None:
|
||||
"""Some modules are stateful, so we have to clear their state."""
|
||||
|
||||
if "keras" in sys.modules:
|
||||
try:
|
||||
keras = sys.modules["keras"]
|
||||
keras.backend.clear_session()
|
||||
except Exception:
|
||||
# We don't want to crash the app if we can't clear the Keras session.
|
||||
pass
|
||||
|
||||
if "matplotlib.pyplot" in sys.modules:
|
||||
try:
|
||||
plt = sys.modules["matplotlib.pyplot"]
|
||||
plt.close("all")
|
||||
except Exception:
|
||||
# We don't want to crash the app if we can't close matplotlib
|
||||
pass
|
||||
|
||||
|
||||
# The reason this is not a decorator is because we want to make it clear at the
|
||||
# calling location that this function is being used.
|
||||
def _log_if_error(fn: Callable[[], None]) -> None:
|
||||
try:
|
||||
fn()
|
||||
except Exception as e:
|
||||
_LOGGER.warning(e)
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The modules in this package are separated from
|
||||
the scriptrunner-package, because they are more or less
|
||||
standalone and other modules import them quite frequently.
|
||||
This separation helps us to remove dependency cycles.
|
||||
"""
|
||||
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from streamlit.runtime.scriptrunner_utils.script_requests import RerunData
|
||||
from streamlit.util import repr_
|
||||
|
||||
|
||||
# We inherit from BaseException to avoid being caught by user code.
|
||||
# For example, having it inherit from Exception might make st.rerun not
|
||||
# work in a try/except block.
|
||||
class ScriptControlException(BaseException): # NOSONAR
|
||||
"""Base exception for ScriptRunner."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StopException(ScriptControlException):
|
||||
"""Silently stop the execution of the user's script."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RerunException(ScriptControlException):
|
||||
"""Silently stop and rerun the user's script."""
|
||||
|
||||
def __init__(self, rerun_data: RerunData):
|
||||
"""Construct a RerunException.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rerun_data : RerunData
|
||||
The RerunData that should be used to rerun the script
|
||||
"""
|
||||
self.rerun_data = rerun_data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr_(self)
|
||||
@@ -0,0 +1,305 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from dataclasses import dataclass, field, replace
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from streamlit import util
|
||||
from streamlit.proto.Common_pb2 import ChatInputValue as ChatInputValueProto
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.proto.ClientState_pb2 import ContextInfo
|
||||
|
||||
|
||||
class ScriptRequestType(Enum):
|
||||
# The ScriptRunner should continue running its script.
|
||||
CONTINUE = "CONTINUE"
|
||||
|
||||
# If the script is running, it should be stopped as soon
|
||||
# as the ScriptRunner reaches an interrupt point.
|
||||
# This is a terminal state.
|
||||
STOP = "STOP"
|
||||
|
||||
# A script rerun has been requested. The ScriptRunner should
|
||||
# handle this request as soon as it reaches an interrupt point.
|
||||
RERUN = "RERUN"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RerunData:
|
||||
"""Data attached to RERUN requests. Immutable."""
|
||||
|
||||
query_string: str = ""
|
||||
widget_states: WidgetStates | None = None
|
||||
page_script_hash: str = ""
|
||||
page_name: str = ""
|
||||
|
||||
# A single fragment_id to append to fragment_id_queue.
|
||||
fragment_id: str | None = None
|
||||
# The queue of fragment_ids waiting to be run.
|
||||
fragment_id_queue: list[str] = field(default_factory=list)
|
||||
is_fragment_scoped_rerun: bool = False
|
||||
# set to true when a script is rerun by the fragment auto-rerun mechanism
|
||||
is_auto_rerun: bool = False
|
||||
# context_info is used to store information from the user browser (e.g. timezone)
|
||||
context_info: ContextInfo | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScriptRequest:
|
||||
"""A STOP or RERUN request and associated data."""
|
||||
|
||||
type: ScriptRequestType
|
||||
_rerun_data: RerunData | None = None
|
||||
|
||||
@property
|
||||
def rerun_data(self) -> RerunData:
|
||||
if self.type is not ScriptRequestType.RERUN:
|
||||
raise RuntimeError("RerunData is only set for RERUN requests.")
|
||||
return cast("RerunData", self._rerun_data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
|
||||
def _fragment_run_should_not_preempt_script(
|
||||
fragment_id_queue: list[str],
|
||||
is_fragment_scoped_rerun: bool,
|
||||
) -> bool:
|
||||
"""Returns whether the currently running script should be preempted due to a
|
||||
fragment rerun.
|
||||
|
||||
Reruns corresponding to fragment runs that weren't caused by calls to
|
||||
`st.rerun(scope="fragment")` should *not* cancel the current script run
|
||||
as doing so will affect elements outside of the fragment.
|
||||
"""
|
||||
return bool(fragment_id_queue) and not is_fragment_scoped_rerun
|
||||
|
||||
|
||||
def _coalesce_widget_states(
|
||||
old_states: WidgetStates | None, new_states: WidgetStates | None
|
||||
) -> WidgetStates | None:
|
||||
"""Coalesce an older WidgetStates into a newer one, and return a new
|
||||
WidgetStates containing the result.
|
||||
|
||||
For most widget values, we just take the latest version.
|
||||
|
||||
However, any trigger_values (which are set by buttons) that are True in
|
||||
`old_states` will be set to True in the coalesced result, so that button
|
||||
presses don't go missing.
|
||||
"""
|
||||
if not old_states and not new_states:
|
||||
return None
|
||||
elif not old_states:
|
||||
return new_states
|
||||
elif not new_states:
|
||||
return old_states
|
||||
|
||||
states_by_id: dict[str, WidgetState] = {
|
||||
wstate.id: wstate for wstate in new_states.widgets
|
||||
}
|
||||
|
||||
trigger_value_types = [
|
||||
("trigger_value", False),
|
||||
("chat_input_value", ChatInputValueProto(data=None)),
|
||||
]
|
||||
for old_state in old_states.widgets:
|
||||
for trigger_value_type, unset_value in trigger_value_types:
|
||||
if (
|
||||
old_state.WhichOneof("value") == trigger_value_type
|
||||
and getattr(old_state, trigger_value_type) != unset_value
|
||||
):
|
||||
new_trigger_val = states_by_id.get(old_state.id)
|
||||
# It should nearly always be the case that new_trigger_val is None
|
||||
# here as trigger values are deleted from the client's WidgetStateManager
|
||||
# as soon as a rerun_script BackMsg is sent to the server. Since it's
|
||||
# impossible to test that the client sends us state in the expected
|
||||
# format in a unit test, we test for this behavior in
|
||||
# e2e_playwright/test_fragment_queue_test.py
|
||||
if not new_trigger_val or (
|
||||
# Ensure the corresponding new_state is also a trigger;
|
||||
# otherwise, a widget that was previously a button/chat_input but no
|
||||
# longer is could get a bad value.
|
||||
new_trigger_val.WhichOneof("value") == trigger_value_type
|
||||
# We only want to take the value of old_state if new_trigger_val is
|
||||
# unset as the old value may be stale if a newer one was entered.
|
||||
and getattr(new_trigger_val, trigger_value_type) == unset_value
|
||||
):
|
||||
states_by_id[old_state.id] = old_state
|
||||
|
||||
coalesced = WidgetStates()
|
||||
coalesced.widgets.extend(states_by_id.values())
|
||||
|
||||
return coalesced
|
||||
|
||||
|
||||
class ScriptRequests:
|
||||
"""An interface for communicating with a ScriptRunner. Thread-safe.
|
||||
|
||||
AppSession makes requests of a ScriptRunner through this class, and
|
||||
ScriptRunner handles those requests.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._state = ScriptRequestType.CONTINUE
|
||||
self._rerun_data = RerunData()
|
||||
|
||||
def request_stop(self) -> None:
|
||||
"""Request that the ScriptRunner stop running. A stopped ScriptRunner
|
||||
can't be used anymore. STOP requests succeed unconditionally.
|
||||
"""
|
||||
with self._lock:
|
||||
self._state = ScriptRequestType.STOP
|
||||
|
||||
def request_rerun(self, new_data: RerunData) -> bool:
|
||||
"""Request that the ScriptRunner rerun its script.
|
||||
|
||||
If the ScriptRunner has been stopped, this request can't be honored:
|
||||
return False.
|
||||
|
||||
Otherwise, record the request and return True. The ScriptRunner will
|
||||
handle the rerun request as soon as it reaches an interrupt point.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
if self._state == ScriptRequestType.STOP:
|
||||
# We can't rerun after being stopped.
|
||||
return False
|
||||
|
||||
if self._state == ScriptRequestType.CONTINUE:
|
||||
# The script is currently running, and we haven't received a request to
|
||||
# rerun it as of yet. We can handle a rerun request unconditionally so
|
||||
# just change self._state and set self._rerun_data.
|
||||
self._state = ScriptRequestType.RERUN
|
||||
|
||||
# Convert from a single fragment_id into fragment_id_queue.
|
||||
if new_data.fragment_id:
|
||||
new_data = replace(
|
||||
new_data,
|
||||
fragment_id=None,
|
||||
fragment_id_queue=[new_data.fragment_id],
|
||||
)
|
||||
|
||||
self._rerun_data = new_data
|
||||
return True
|
||||
|
||||
if self._state == ScriptRequestType.RERUN:
|
||||
# We already have an existing Rerun request, so we can coalesce the new
|
||||
# rerun request into the existing one.
|
||||
|
||||
coalesced_states = _coalesce_widget_states(
|
||||
self._rerun_data.widget_states, new_data.widget_states
|
||||
)
|
||||
|
||||
if new_data.fragment_id:
|
||||
# This RERUN request corresponds to a new fragment run. We append
|
||||
# the new fragment ID to the end of the current fragment_id_queue if
|
||||
# it isn't already contained in it.
|
||||
fragment_id_queue = [*self._rerun_data.fragment_id_queue]
|
||||
|
||||
if new_data.fragment_id not in fragment_id_queue:
|
||||
fragment_id_queue.append(new_data.fragment_id)
|
||||
elif new_data.fragment_id_queue:
|
||||
# new_data contains a new fragment_id_queue, so we just use it.
|
||||
fragment_id_queue = new_data.fragment_id_queue
|
||||
else:
|
||||
# Otherwise, this is a request to rerun the full script, so we want
|
||||
# to clear out any fragments we have queued to run since they'll all
|
||||
# be run with the full script anyway.
|
||||
fragment_id_queue = []
|
||||
|
||||
self._rerun_data = RerunData(
|
||||
query_string=new_data.query_string,
|
||||
widget_states=coalesced_states,
|
||||
page_script_hash=new_data.page_script_hash,
|
||||
page_name=new_data.page_name,
|
||||
fragment_id_queue=fragment_id_queue,
|
||||
is_fragment_scoped_rerun=new_data.is_fragment_scoped_rerun,
|
||||
is_auto_rerun=new_data.is_auto_rerun,
|
||||
context_info=new_data.context_info,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
# We'll never get here
|
||||
raise RuntimeError(f"Unrecognized ScriptRunnerState: {self._state}")
|
||||
|
||||
def on_scriptrunner_yield(self) -> ScriptRequest | None:
|
||||
"""Called by the ScriptRunner when it's at a yield point.
|
||||
|
||||
If we have no request or a RERUN request corresponding to one or more fragments
|
||||
(that is not a fragment-scoped rerun), return None.
|
||||
|
||||
If we have a (full script or fragment-scoped) RERUN request, return the request
|
||||
and set our internal state to CONTINUE.
|
||||
|
||||
If we have a STOP request, return the request and remain stopped.
|
||||
"""
|
||||
if self._state == ScriptRequestType.CONTINUE or (
|
||||
self._state == ScriptRequestType.RERUN
|
||||
and _fragment_run_should_not_preempt_script(
|
||||
self._rerun_data.fragment_id_queue,
|
||||
self._rerun_data.is_fragment_scoped_rerun,
|
||||
)
|
||||
):
|
||||
# We avoid taking the lock in the common cases described above. If a STOP or
|
||||
# preempting RERUN request is received after we've taken this code path, it
|
||||
# will be handled at the next `on_scriptrunner_yield`, or when
|
||||
# `on_scriptrunner_ready` is called.
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
if self._state == ScriptRequestType.RERUN:
|
||||
# We already made this check in the fast-path above but need to do so
|
||||
# again in case our state changed while we were waiting on the lock.
|
||||
if _fragment_run_should_not_preempt_script(
|
||||
self._rerun_data.fragment_id_queue,
|
||||
self._rerun_data.is_fragment_scoped_rerun,
|
||||
):
|
||||
return None
|
||||
|
||||
self._state = ScriptRequestType.CONTINUE
|
||||
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
|
||||
|
||||
assert self._state == ScriptRequestType.STOP
|
||||
return ScriptRequest(ScriptRequestType.STOP)
|
||||
|
||||
def on_scriptrunner_ready(self) -> ScriptRequest:
|
||||
"""Called by the ScriptRunner when it's about to run its script for
|
||||
the first time, and also after its script has successfully completed.
|
||||
|
||||
If we have a RERUN request, return the request and set
|
||||
our internal state to CONTINUE.
|
||||
|
||||
If we have a STOP request or no request, set our internal state
|
||||
to STOP.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state == ScriptRequestType.RERUN:
|
||||
self._state = ScriptRequestType.CONTINUE
|
||||
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
|
||||
|
||||
# If we don't have a rerun request, unconditionally change our
|
||||
# state to STOP.
|
||||
self._state = ScriptRequestType.STOP
|
||||
return ScriptRequest(ScriptRequestType.STOP)
|
||||
@@ -0,0 +1,288 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import contextvars
|
||||
import threading
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Callable,
|
||||
Final,
|
||||
Union,
|
||||
)
|
||||
from urllib import parse
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from streamlit.errors import (
|
||||
NoSessionContext,
|
||||
StreamlitAPIException,
|
||||
StreamlitSetPageConfigMustBeFirstCommandError,
|
||||
)
|
||||
from streamlit.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from streamlit.cursor import RunningCursor
|
||||
from streamlit.proto.ClientState_pb2 import ContextInfo
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.proto.PageProfile_pb2 import Command
|
||||
from streamlit.runtime.fragment import FragmentStorage
|
||||
from streamlit.runtime.pages_manager import PagesManager
|
||||
from streamlit.runtime.scriptrunner_utils.script_requests import ScriptRequests
|
||||
from streamlit.runtime.state import SafeSessionState
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
UserInfo: TypeAlias = dict[str, Union[str, bool, None]]
|
||||
|
||||
|
||||
# If true, it indicates that we are in a cached function that disallows the usage of
|
||||
# widgets. Using contextvars to be thread-safe.
|
||||
in_cached_function: contextvars.ContextVar[bool] = contextvars.ContextVar(
|
||||
"in_cached_function", default=False
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptRunContext:
|
||||
"""A context object that contains data for a "script run" - that is,
|
||||
data that's scoped to a single ScriptRunner execution (and therefore also
|
||||
scoped to a single connected "session").
|
||||
|
||||
ScriptRunContext is used internally by virtually every `st.foo()` function.
|
||||
It is accessed only from the script thread that's created by ScriptRunner,
|
||||
or from app-created helper threads that have been "attached" to the
|
||||
ScriptRunContext via `add_script_run_ctx`.
|
||||
|
||||
Streamlit code typically retrieves the active ScriptRunContext via the
|
||||
`get_script_run_ctx` function.
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
_enqueue: Callable[[ForwardMsg], None]
|
||||
query_string: str
|
||||
session_state: SafeSessionState
|
||||
uploaded_file_mgr: UploadedFileManager
|
||||
main_script_path: str
|
||||
user_info: UserInfo
|
||||
fragment_storage: FragmentStorage
|
||||
pages_manager: PagesManager
|
||||
|
||||
context_info: ContextInfo | None = None
|
||||
gather_usage_stats: bool = False
|
||||
command_tracking_deactivated: bool = False
|
||||
tracked_commands: list[Command] = field(default_factory=list)
|
||||
tracked_commands_counter: Counter[str] = field(default_factory=collections.Counter)
|
||||
_set_page_config_allowed: bool = True
|
||||
_has_script_started: bool = False
|
||||
widget_ids_this_run: set[str] = field(default_factory=set)
|
||||
widget_user_keys_this_run: set[str] = field(default_factory=set)
|
||||
form_ids_this_run: set[str] = field(default_factory=set)
|
||||
cursors: dict[int, RunningCursor] = field(default_factory=dict)
|
||||
script_requests: ScriptRequests | None = None
|
||||
current_fragment_id: str | None = None
|
||||
fragment_ids_this_run: list[str] | None = None
|
||||
new_fragment_ids: set[str] = field(default_factory=set)
|
||||
_active_script_hash: str = ""
|
||||
# we allow only one dialog to be open at the same time
|
||||
has_dialog_opened: bool = False
|
||||
|
||||
# TODO(willhuang1997): Remove this variable when experimental query params are removed
|
||||
_experimental_query_params_used = False
|
||||
_production_query_params_used = False
|
||||
|
||||
@property
|
||||
def page_script_hash(self):
|
||||
return self.pages_manager.current_page_script_hash
|
||||
|
||||
@property
|
||||
def active_script_hash(self):
|
||||
return self._active_script_hash
|
||||
|
||||
@property
|
||||
def main_script_parent(self) -> Path:
|
||||
return self.pages_manager.main_script_parent
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_with_active_hash(self, page_hash: str):
|
||||
original_page_hash = self._active_script_hash
|
||||
self._active_script_hash = page_hash
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# in the event of any exception, ensure we set the active hash back
|
||||
self._active_script_hash = original_page_hash
|
||||
|
||||
def set_mpa_v2_page(self, page_script_hash: str):
|
||||
self._active_script_hash = self.pages_manager.main_script_hash
|
||||
self.pages_manager.set_current_page_script_hash(page_script_hash)
|
||||
|
||||
def reset(
|
||||
self,
|
||||
query_string: str = "",
|
||||
page_script_hash: str = "",
|
||||
fragment_ids_this_run: list[str] | None = None,
|
||||
context_info: ContextInfo | None = None,
|
||||
) -> None:
|
||||
self.cursors = {}
|
||||
self.widget_ids_this_run = set()
|
||||
self.widget_user_keys_this_run = set()
|
||||
self.form_ids_this_run = set()
|
||||
self.query_string = query_string
|
||||
self.context_info = context_info
|
||||
self.pages_manager.set_current_page_script_hash(page_script_hash)
|
||||
self._active_script_hash = self.pages_manager.main_script_hash
|
||||
# Permit set_page_config when the ScriptRunContext is reused on a rerun
|
||||
self._set_page_config_allowed = True
|
||||
self._has_script_started = False
|
||||
self.command_tracking_deactivated: bool = False
|
||||
self.tracked_commands = []
|
||||
self.tracked_commands_counter = collections.Counter()
|
||||
self.current_fragment_id = None
|
||||
self.current_fragment_delta_path: list[int] = []
|
||||
self.fragment_ids_this_run = fragment_ids_this_run
|
||||
self.new_fragment_ids = set()
|
||||
self.has_dialog_opened = False
|
||||
in_cached_function.set(False)
|
||||
|
||||
parsed_query_params = parse.parse_qs(query_string, keep_blank_values=True)
|
||||
with self.session_state.query_params() as qp:
|
||||
qp.clear_with_no_forward_msg()
|
||||
for key, val in parsed_query_params.items():
|
||||
if len(val) == 0:
|
||||
qp.set_with_no_forward_msg(key, val="")
|
||||
elif len(val) == 1:
|
||||
qp.set_with_no_forward_msg(key, val=val[-1])
|
||||
else:
|
||||
qp.set_with_no_forward_msg(key, val)
|
||||
|
||||
def on_script_start(self) -> None:
|
||||
self._has_script_started = True
|
||||
|
||||
def enqueue(self, msg: ForwardMsg) -> None:
|
||||
"""Enqueue a ForwardMsg for this context's session."""
|
||||
if msg.HasField("page_config_changed") and not self._set_page_config_allowed:
|
||||
raise StreamlitSetPageConfigMustBeFirstCommandError()
|
||||
|
||||
# We want to disallow set_page config if one of the following occurs:
|
||||
# - set_page_config was called on this message
|
||||
# - The script has already started and a different st call occurs (a delta)
|
||||
if msg.HasField("page_config_changed") or (
|
||||
msg.HasField("delta") and self._has_script_started
|
||||
):
|
||||
self._set_page_config_allowed = False
|
||||
|
||||
msg.metadata.active_script_hash = self.active_script_hash
|
||||
|
||||
# Pass the message up to our associated ScriptRunner.
|
||||
self._enqueue(msg)
|
||||
|
||||
def ensure_single_query_api_used(self):
|
||||
if self._experimental_query_params_used and self._production_query_params_used:
|
||||
raise StreamlitAPIException(
|
||||
"Using `st.query_params` together with either `st.experimental_get_query_params` "
|
||||
"or `st.experimental_set_query_params` is not supported. Please convert your app "
|
||||
"to only use `st.query_params`"
|
||||
)
|
||||
|
||||
def mark_experimental_query_params_used(self):
|
||||
self._experimental_query_params_used = True
|
||||
self.ensure_single_query_api_used()
|
||||
|
||||
def mark_production_query_params_used(self):
|
||||
self._production_query_params_used = True
|
||||
self.ensure_single_query_api_used()
|
||||
|
||||
|
||||
SCRIPT_RUN_CONTEXT_ATTR_NAME: Final = "streamlit_script_run_ctx"
|
||||
|
||||
|
||||
def add_script_run_ctx(
|
||||
thread: threading.Thread | None = None, ctx: ScriptRunContext | None = None
|
||||
):
|
||||
"""Adds the current ScriptRunContext to a newly-created thread.
|
||||
|
||||
This should be called from this thread's parent thread,
|
||||
before the new thread starts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
thread : threading.Thread
|
||||
The thread to attach the current ScriptRunContext to.
|
||||
ctx : ScriptRunContext or None
|
||||
The ScriptRunContext to add, or None to use the current thread's
|
||||
ScriptRunContext.
|
||||
|
||||
Returns
|
||||
-------
|
||||
threading.Thread
|
||||
The same thread that was passed in, for chaining.
|
||||
|
||||
"""
|
||||
if thread is None:
|
||||
thread = threading.current_thread()
|
||||
if ctx is None:
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is not None:
|
||||
setattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, ctx)
|
||||
return thread
|
||||
|
||||
|
||||
def get_script_run_ctx(suppress_warning: bool = False) -> ScriptRunContext | None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
suppress_warning : bool
|
||||
If True, don't log a warning if there's no ScriptRunContext.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ScriptRunContext | None
|
||||
The current thread's ScriptRunContext, or None if it doesn't have one.
|
||||
|
||||
"""
|
||||
thread = threading.current_thread()
|
||||
ctx: ScriptRunContext | None = getattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, None)
|
||||
if ctx is None and not suppress_warning:
|
||||
# Only warn about a missing ScriptRunContext if suppress_warning is False, and
|
||||
# we were started via `streamlit run`. Otherwise, the user is likely running a
|
||||
# script "bare", and doesn't need to be warned about streamlit
|
||||
# bits that are irrelevant when not connected to a session.
|
||||
_LOGGER.warning(
|
||||
"Thread '%s': missing ScriptRunContext! This warning can be ignored when "
|
||||
"running in bare mode.",
|
||||
thread.name,
|
||||
)
|
||||
|
||||
return ctx
|
||||
|
||||
|
||||
def enqueue_message(msg: ForwardMsg) -> None:
|
||||
"""Enqueues a ForwardMsg proto to send to the app."""
|
||||
ctx = get_script_run_ctx()
|
||||
|
||||
if ctx is None:
|
||||
raise NoSessionContext()
|
||||
|
||||
if ctx.current_fragment_id and msg.WhichOneof("type") == "delta":
|
||||
msg.delta.fragment_id = ctx.current_fragment_id
|
||||
|
||||
ctx.enqueue(msg)
|
||||
531
myenv/lib/python3.11/site-packages/streamlit/runtime/secrets.py
Normal file
531
myenv/lib/python3.11/site-packages/streamlit/runtime/secrets.py
Normal file
@@ -0,0 +1,531 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import ItemsView, Iterator, KeysView, Mapping, ValuesView
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Final,
|
||||
NoReturn,
|
||||
)
|
||||
|
||||
from blinker import Signal
|
||||
|
||||
import streamlit as st
|
||||
import streamlit.watcher.path_watcher
|
||||
from streamlit import runtime
|
||||
from streamlit.errors import StreamlitSecretNotFoundError
|
||||
from streamlit.logger import get_logger
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
class SecretErrorMessages:
|
||||
"""SecretErrorMessages stores all error messages we use for secrets to allow customization for different environments.
|
||||
For example Streamlit Cloud can customize the message to be different than the open source.
|
||||
|
||||
For internal use, may change in future releases without notice.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.missing_attr_message = lambda attr_name: (
|
||||
f'st.secrets has no attribute "{attr_name}". '
|
||||
f"Did you forget to add it to secrets.toml, mount it to secret directory, or the app settings on Streamlit Cloud? "
|
||||
f"More info: https://docs.streamlit.io/deploy/streamlit-community-cloud/deploy-your-app/secrets-management"
|
||||
)
|
||||
self.missing_key_message = lambda key: (
|
||||
f'st.secrets has no key "{key}". '
|
||||
f"Did you forget to add it to secrets.toml, mount it to secret directory, or the app settings on Streamlit Cloud? "
|
||||
f"More info: https://docs.streamlit.io/deploy/streamlit-community-cloud/deploy-your-app/secrets-management"
|
||||
)
|
||||
self.no_secrets_found = lambda file_paths: (
|
||||
f"No secrets found. Valid paths for a secrets.toml file or secret directories are: {', '.join(file_paths)}"
|
||||
)
|
||||
self.error_parsing_file_at_path = (
|
||||
lambda path, ex: f"Error parsing secrets file at {path}: {ex}"
|
||||
)
|
||||
self.subfolder_path_is_not_a_folder = lambda sub_folder_path: (
|
||||
f"{sub_folder_path} is not a folder. "
|
||||
"To use directory based secrets, mount every secret in a subfolder under the secret directory"
|
||||
)
|
||||
self.invalid_secret_path = lambda path: (
|
||||
f"Invalid secrets path: {path}: path is not a .toml file or a directory"
|
||||
)
|
||||
|
||||
def set_missing_attr_message(self, message: Callable[[str], str]) -> None:
|
||||
"""Set the missing attribute error message."""
|
||||
self.missing_attr_message = message
|
||||
|
||||
def set_missing_key_message(self, message: Callable[[str], str]) -> None:
|
||||
"""Set the missing key error message."""
|
||||
self.missing_key_message = message
|
||||
|
||||
def set_no_secrets_found_message(self, message: Callable[[list[str]], str]) -> None:
|
||||
"""Set the no secrets found error message."""
|
||||
self.no_secrets_found = message
|
||||
|
||||
def set_error_parsing_file_at_path_message(
|
||||
self, message: Callable[[str, Exception], str]
|
||||
) -> None:
|
||||
"""Set the error parsing file at path error message."""
|
||||
self.error_parsing_file_at_path = message
|
||||
|
||||
def set_subfolder_path_is_not_a_folder_message(
|
||||
self, message: Callable[[str], str]
|
||||
) -> None:
|
||||
"""Set the subfolder path is not a folder error message."""
|
||||
self.subfolder_path_is_not_a_folder = message
|
||||
|
||||
def set_invalid_secret_path_message(self, message: Callable[[str], str]) -> None:
|
||||
"""Set the invalid secret path error message."""
|
||||
self.invalid_secret_path = message
|
||||
|
||||
def get_missing_attr_message(self, attr_name: str) -> str:
|
||||
"""Get the missing attribute error message."""
|
||||
return self.missing_attr_message(attr_name)
|
||||
|
||||
def get_missing_key_message(self, key: str) -> str:
|
||||
"""Get the missing key error message."""
|
||||
return self.missing_key_message(key)
|
||||
|
||||
def get_no_secrets_found_message(self, file_paths: list[str]) -> str:
|
||||
"""Get the no secrets found error message."""
|
||||
return self.no_secrets_found(file_paths)
|
||||
|
||||
def get_error_parsing_file_at_path_message(self, path: str, ex: Exception) -> str:
|
||||
"""Get the error parsing file at path error message."""
|
||||
return self.error_parsing_file_at_path(path, ex)
|
||||
|
||||
def get_subfolder_path_is_not_a_folder_message(self, sub_folder_path: str) -> str:
|
||||
"""Get the subfolder path is not a folder error message."""
|
||||
return self.subfolder_path_is_not_a_folder(sub_folder_path)
|
||||
|
||||
def get_invalid_secret_path_message(self, path: str) -> str:
|
||||
"""Get the invalid secret path error message."""
|
||||
return self.invalid_secret_path(path)
|
||||
|
||||
|
||||
secret_error_messages_singleton: Final = SecretErrorMessages()
|
||||
|
||||
|
||||
def _convert_to_dict(obj: Mapping[str, Any] | AttrDict) -> dict[str, Any]:
|
||||
"""Convert Mapping or AttrDict objects to dictionaries."""
|
||||
if isinstance(obj, AttrDict):
|
||||
return obj.to_dict()
|
||||
return {k: v.to_dict() if isinstance(v, AttrDict) else v for k, v in obj.items()}
|
||||
|
||||
|
||||
def _missing_attr_error_message(attr_name: str) -> str:
|
||||
return secret_error_messages_singleton.get_missing_attr_message(attr_name)
|
||||
|
||||
|
||||
def _missing_key_error_message(key: str) -> str:
|
||||
return secret_error_messages_singleton.get_missing_key_message(key)
|
||||
|
||||
|
||||
class AttrDict(Mapping[str, Any]):
|
||||
"""We use AttrDict to wrap up dictionary values from secrets
|
||||
to provide dot access to nested secrets.
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
self.__dict__["__nested_secrets__"] = dict(value)
|
||||
|
||||
@staticmethod
|
||||
def _maybe_wrap_in_attr_dict(value) -> Any:
|
||||
if not isinstance(value, Mapping):
|
||||
return value
|
||||
else:
|
||||
return AttrDict(value)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.__nested_secrets__)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self.__nested_secrets__)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
try:
|
||||
value = self.__nested_secrets__[key]
|
||||
return self._maybe_wrap_in_attr_dict(value)
|
||||
except KeyError:
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
def __getattr__(self, attr_name: str) -> Any:
|
||||
try:
|
||||
value = self.__nested_secrets__[attr_name]
|
||||
return self._maybe_wrap_in_attr_dict(value)
|
||||
except KeyError:
|
||||
raise AttributeError(_missing_attr_error_message(attr_name))
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.__nested_secrets__)
|
||||
|
||||
def __setitem__(self, key, value) -> NoReturn:
|
||||
raise TypeError("Secrets does not support item assignment.")
|
||||
|
||||
def __setattr__(self, key, value) -> NoReturn:
|
||||
raise TypeError("Secrets does not support attribute assignment.")
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return deepcopy(self.__nested_secrets__)
|
||||
|
||||
|
||||
class Secrets(Mapping[str, Any]):
|
||||
"""A dict-like class that stores secrets.
|
||||
Parses secrets.toml on-demand. Cannot be externally mutated.
|
||||
|
||||
Safe to use from multiple threads.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Our secrets dict.
|
||||
self._secrets: Mapping[str, Any] | None = None
|
||||
self._lock = threading.RLock()
|
||||
self._file_watchers_installed = False
|
||||
|
||||
self.file_change_listener = Signal(
|
||||
doc="Emitted when a `secrets.toml` file has been changed."
|
||||
)
|
||||
|
||||
def load_if_toml_exists(self) -> bool:
|
||||
"""Load secrets.toml files from disk if they exists. If none exist,
|
||||
no exception will be raised. (If a file exists but is malformed,
|
||||
an exception *will* be raised.).
|
||||
|
||||
Returns True if a secrets.toml file was successfully parsed, False otherwise.
|
||||
|
||||
Thread-safe.
|
||||
"""
|
||||
try:
|
||||
self._parse()
|
||||
|
||||
return True
|
||||
except StreamlitSecretNotFoundError:
|
||||
# No secrets.toml files exist. That's fine.
|
||||
return False
|
||||
|
||||
def set_suppress_print_error_on_exception(
|
||||
self, suppress_print_error_on_exception: bool
|
||||
) -> None:
|
||||
"""Left in place for compatibility with integrations until integration
|
||||
code can be updated.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Clear the secrets dictionary and remove any secrets that were
|
||||
added to os.environ.
|
||||
|
||||
Thread-safe.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._secrets is None:
|
||||
return
|
||||
|
||||
for k, v in self._secrets.items():
|
||||
self._maybe_delete_environment_variable(k, v)
|
||||
self._secrets = None
|
||||
|
||||
def _parse_toml_file(self, path: str) -> tuple[Mapping[str, Any], bool]:
|
||||
"""Parse a TOML file and return the secrets as a dictionary."""
|
||||
secrets = {}
|
||||
found_secrets_file = False
|
||||
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
secrets_file_str = f.read()
|
||||
|
||||
found_secrets_file = True
|
||||
except FileNotFoundError:
|
||||
# the default config for secrets contains two paths. It's likely one of will not have secrets file.
|
||||
return {}, False
|
||||
|
||||
try:
|
||||
import toml
|
||||
|
||||
secrets.update(toml.loads(secrets_file_str))
|
||||
except (TypeError, toml.TomlDecodeError) as ex:
|
||||
msg = (
|
||||
secret_error_messages_singleton.get_error_parsing_file_at_path_message(
|
||||
path, ex
|
||||
)
|
||||
)
|
||||
raise StreamlitSecretNotFoundError(msg) from ex
|
||||
|
||||
return secrets, found_secrets_file
|
||||
|
||||
def _parse_directory(self, path: str) -> tuple[Mapping[str, Any], bool]:
|
||||
"""Parse a directory for secrets. Directory style can be used to support Kubernetes secrets that are mounted to folders.
|
||||
|
||||
Example structure:
|
||||
- top_level_secret_folder
|
||||
- user_pass_secret (folder)
|
||||
- username (file), content: myuser
|
||||
- password (file), content: mypassword
|
||||
- my_plain_secret (folder)
|
||||
- regular_secret (file), content: mysecret
|
||||
|
||||
See: https://kubernetes.io/docs/tasks/inject-data-application/distribute-credentials-secure/#create-a-pod-that-has-access-to-the-secret-data-through-a-volume
|
||||
And: https://docs.snowflake.com/en/developer-guide/snowpark-container-services/additional-considerations-services-jobs#passing-secrets-in-local-container-files
|
||||
"""
|
||||
secrets: dict[str, Any] = {}
|
||||
found_secrets_file = False
|
||||
|
||||
for dirname in os.listdir(path):
|
||||
sub_folder_path = os.path.join(path, dirname)
|
||||
if not os.path.isdir(sub_folder_path):
|
||||
error_msg = secret_error_messages_singleton.get_subfolder_path_is_not_a_folder_message(
|
||||
sub_folder_path
|
||||
)
|
||||
raise StreamlitSecretNotFoundError(error_msg)
|
||||
sub_secrets = {}
|
||||
|
||||
for filename in os.listdir(sub_folder_path):
|
||||
file_path = os.path.join(sub_folder_path, filename)
|
||||
|
||||
# ignore folders
|
||||
if os.path.isdir(file_path):
|
||||
continue
|
||||
|
||||
with open(file_path) as f:
|
||||
sub_secrets[filename] = f.read().strip()
|
||||
found_secrets_file = True
|
||||
|
||||
if len(sub_secrets) == 1:
|
||||
# if there's just one file, collapse it so it's directly under `dirname`
|
||||
secrets[dirname] = sub_secrets[list(sub_secrets.keys())[0]]
|
||||
else:
|
||||
secrets[dirname] = sub_secrets
|
||||
|
||||
return secrets, found_secrets_file
|
||||
|
||||
def _parse_file_path(self, path: str) -> tuple[Mapping[str, Any], bool]:
|
||||
if path.endswith(".toml"):
|
||||
return self._parse_toml_file(path)
|
||||
|
||||
if os.path.isdir(path):
|
||||
return self._parse_directory(path)
|
||||
|
||||
error_msg = secret_error_messages_singleton.get_invalid_secret_path_message(
|
||||
path
|
||||
)
|
||||
raise StreamlitSecretNotFoundError(error_msg)
|
||||
|
||||
def _parse(self) -> Mapping[str, Any]:
|
||||
"""Parse our secrets.toml files if they're not already parsed.
|
||||
This function is safe to call from multiple threads.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
print_exceptions : bool
|
||||
If True, then exceptions will be printed with `st.error` before
|
||||
being re-raised.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitSecretNotFoundError
|
||||
Raised if secrets.toml doesn't exist.
|
||||
|
||||
"""
|
||||
# Avoid taking a lock for the common case where secrets are already
|
||||
# loaded.
|
||||
secrets = self._secrets
|
||||
if secrets is not None:
|
||||
return secrets
|
||||
|
||||
with self._lock:
|
||||
if self._secrets is not None:
|
||||
return self._secrets
|
||||
|
||||
secrets = {}
|
||||
|
||||
file_paths = st.config.get_option("secrets.files")
|
||||
found_secrets_file = False
|
||||
for path in file_paths:
|
||||
path_secrets, found_secrets_file_in_path = self._parse_file_path(path)
|
||||
found_secrets_file = found_secrets_file or found_secrets_file_in_path
|
||||
secrets.update(path_secrets)
|
||||
|
||||
if not found_secrets_file:
|
||||
error_msg = (
|
||||
secret_error_messages_singleton.get_no_secrets_found_message(
|
||||
file_paths
|
||||
)
|
||||
)
|
||||
raise StreamlitSecretNotFoundError(error_msg)
|
||||
|
||||
for k, v in secrets.items():
|
||||
self._maybe_set_environment_variable(k, v)
|
||||
|
||||
self._secrets = secrets
|
||||
self._maybe_install_file_watchers()
|
||||
|
||||
return self._secrets
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Converts the secrets store into a nested dictionary, where nested AttrDict objects are also converted into dictionaries."""
|
||||
secrets = self._parse()
|
||||
return _convert_to_dict(secrets)
|
||||
|
||||
@staticmethod
|
||||
def _maybe_set_environment_variable(k: Any, v: Any) -> None:
|
||||
"""Add the given key/value pair to os.environ if the value
|
||||
is a string, int, or float.
|
||||
"""
|
||||
value_type = type(v)
|
||||
if value_type in (str, int, float):
|
||||
os.environ[k] = str(v)
|
||||
|
||||
@staticmethod
|
||||
def _maybe_delete_environment_variable(k: Any, v: Any) -> None:
|
||||
"""Remove the given key/value pair from os.environ if the value
|
||||
is a string, int, or float.
|
||||
"""
|
||||
value_type = type(v)
|
||||
if value_type in (str, int, float) and os.environ.get(k) == v:
|
||||
del os.environ[k]
|
||||
|
||||
def _maybe_install_file_watchers(self) -> None:
|
||||
with self._lock:
|
||||
if self._file_watchers_installed:
|
||||
return
|
||||
|
||||
file_paths = st.config.get_option("secrets.files")
|
||||
for path in file_paths:
|
||||
try:
|
||||
if path.endswith(".toml"):
|
||||
streamlit.watcher.path_watcher.watch_file(
|
||||
path,
|
||||
self._on_secrets_changed,
|
||||
watcher_type="poll",
|
||||
)
|
||||
else:
|
||||
streamlit.watcher.path_watcher.watch_dir(
|
||||
path,
|
||||
self._on_secrets_changed,
|
||||
watcher_type="poll",
|
||||
)
|
||||
except FileNotFoundError:
|
||||
# A user may only have one secrets.toml file defined, so we'd expect
|
||||
# FileNotFoundErrors to be raised when attempting to install a
|
||||
# watcher on the nonexistent ones.
|
||||
pass
|
||||
|
||||
# We set file_watchers_installed to True even if the installation attempt
|
||||
# failed to avoid repeatedly trying to install it.
|
||||
self._file_watchers_installed = True
|
||||
|
||||
def _on_secrets_changed(self, changed_file_path) -> None:
|
||||
with self._lock:
|
||||
_LOGGER.debug("Secret path %s changed, reloading", changed_file_path)
|
||||
self._reset()
|
||||
self._parse()
|
||||
|
||||
# Emit a signal to notify receivers that the `secrets.toml` file
|
||||
# has been changed.
|
||||
self.file_change_listener.send()
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
"""Return the value with the given key. If no such key
|
||||
exists, raise an AttributeError.
|
||||
|
||||
Thread-safe.
|
||||
"""
|
||||
try:
|
||||
value = self._parse()[key]
|
||||
if not isinstance(value, Mapping):
|
||||
return value
|
||||
else:
|
||||
return AttrDict(value)
|
||||
# We add FileNotFoundError since __getattr__ is expected to only raise
|
||||
# AttributeError. Without handling FileNotFoundError, unittests.mocks
|
||||
# fails during mock creation on Python3.9
|
||||
except (KeyError, FileNotFoundError):
|
||||
raise AttributeError(_missing_attr_error_message(key))
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Return the value with the given key. If no such key
|
||||
exists, raise a KeyError.
|
||||
|
||||
Thread-safe.
|
||||
"""
|
||||
try:
|
||||
value = self._parse()[key]
|
||||
if not isinstance(value, Mapping):
|
||||
return value
|
||||
else:
|
||||
return AttrDict(value)
|
||||
except KeyError:
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
# Allow internal attributes to be set
|
||||
if key in {
|
||||
"_secrets",
|
||||
"_lock",
|
||||
"_file_watchers_installed",
|
||||
"_suppress_print_error_on_exception",
|
||||
"file_change_listener",
|
||||
"load_if_toml_exists",
|
||||
}:
|
||||
super().__setattr__(key, value)
|
||||
else:
|
||||
raise TypeError("Secrets does not support attribute assignment.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# If the runtime is NOT initialized, it is a method call outside
|
||||
# the streamlit app, so we avoid reading the secrets file as it may not exist.
|
||||
# If the runtime is initialized, display the contents of the file and
|
||||
# the file must already exist.
|
||||
"""A string representation of the contents of the dict. Thread-safe."""
|
||||
if not runtime.exists():
|
||||
return f"{self.__class__.__name__}"
|
||||
return repr(self._parse())
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of entries in the dict. Thread-safe."""
|
||||
return len(self._parse())
|
||||
|
||||
def has_key(self, k: str) -> bool:
|
||||
"""True if the given key is in the dict. Thread-safe."""
|
||||
return k in self._parse()
|
||||
|
||||
def keys(self) -> KeysView[str]:
|
||||
"""A view of the keys in the dict. Thread-safe."""
|
||||
return self._parse().keys()
|
||||
|
||||
def values(self) -> ValuesView[Any]:
|
||||
"""A view of the values in the dict. Thread-safe."""
|
||||
return self._parse().values()
|
||||
|
||||
def items(self) -> ItemsView[str, Any]:
|
||||
"""A view of the key-value items in the dict. Thread-safe."""
|
||||
return self._parse().items()
|
||||
|
||||
def __contains__(self, key: Any) -> bool:
|
||||
"""True if the given key is in the dict. Thread-safe."""
|
||||
return key in self._parse()
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
"""An iterator over the keys in the dict. Thread-safe."""
|
||||
return iter(self._parse())
|
||||
|
||||
|
||||
secrets_singleton: Final = Secrets()
|
||||
@@ -0,0 +1,394 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Callable, Protocol, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.app_session import AppSession
|
||||
from streamlit.runtime.script_data import ScriptData
|
||||
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
|
||||
|
||||
class SessionClientDisconnectedError(Exception):
|
||||
"""Raised by operations on a disconnected SessionClient."""
|
||||
|
||||
|
||||
class SessionClient(Protocol):
|
||||
"""Interface for sending data to a session's client."""
|
||||
|
||||
@abstractmethod
|
||||
def write_forward_msg(self, msg: ForwardMsg) -> None:
|
||||
"""Deliver a ForwardMsg to the client.
|
||||
|
||||
If the SessionClient has been disconnected, it should raise a
|
||||
SessionClientDisconnectedError.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveSessionInfo:
|
||||
"""Type containing data related to an active session.
|
||||
|
||||
This type is nearly identical to SessionInfo. The difference is that when using it,
|
||||
we are guaranteed that SessionClient is not None.
|
||||
"""
|
||||
|
||||
client: SessionClient
|
||||
session: AppSession
|
||||
script_run_count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInfo:
|
||||
"""Type containing data related to an AppSession.
|
||||
|
||||
For each AppSession, the Runtime tracks that session's
|
||||
script_run_count. This is used to track the age of messages in
|
||||
the ForwardMsgCache.
|
||||
"""
|
||||
|
||||
client: SessionClient | None
|
||||
session: AppSession
|
||||
script_run_count: int = 0
|
||||
|
||||
def is_active(self) -> bool:
|
||||
return self.client is not None
|
||||
|
||||
def to_active(self) -> ActiveSessionInfo:
|
||||
assert self.is_active(), "A SessionInfo with no client cannot be active!"
|
||||
|
||||
# NOTE: The cast here (rather than copying this SessionInfo's fields into a new
|
||||
# ActiveSessionInfo) is important as the Runtime expects to be able to mutate
|
||||
# what's returned from get_active_session_info to increment script_run_count.
|
||||
return cast("ActiveSessionInfo", self)
|
||||
|
||||
|
||||
class SessionStorageError(Exception):
|
||||
"""Exception class for errors raised by SessionStorage.
|
||||
|
||||
The original error that causes a SessionStorageError to be (re)raised will generally
|
||||
be an I/O error specific to the concrete SessionStorage implementation.
|
||||
"""
|
||||
|
||||
|
||||
class SessionStorage(Protocol):
|
||||
@abstractmethod
|
||||
def get(self, session_id: str) -> SessionInfo | None:
|
||||
"""Return the SessionInfo corresponding to session_id, or None if one does not
|
||||
exist.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The unique ID of the session being fetched.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SessionInfo or None
|
||||
|
||||
Raises
|
||||
------
|
||||
SessionStorageError
|
||||
Raised if an error occurs while attempting to fetch the session. This will
|
||||
generally happen if there is an error with the underlying storage backend
|
||||
(e.g. if we lose our connection to it).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def save(self, session_info: SessionInfo) -> None:
|
||||
"""Save the given session.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_info
|
||||
The SessionInfo being saved.
|
||||
|
||||
Raises
|
||||
------
|
||||
SessionStorageError
|
||||
Raised if an error occurs while saving the given session.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, session_id: str) -> None:
|
||||
"""Mark the session corresponding to session_id for deletion and stop tracking
|
||||
it.
|
||||
|
||||
Note that:
|
||||
- Calling delete on an ID corresponding to a nonexistent session is a no-op.
|
||||
- Calling delete on an ID should cause the given session to no longer be
|
||||
tracked by this SessionStorage, but exactly when and how the session's data
|
||||
is eventually cleaned up is a detail left up to the implementation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The unique ID of the session to delete.
|
||||
|
||||
Raises
|
||||
------
|
||||
SessionStorageError
|
||||
Raised if an error occurs while attempting to delete the session.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def list(self) -> list[SessionInfo]:
|
||||
"""List all sessions tracked by this SessionStorage.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[SessionInfo]
|
||||
|
||||
Raises
|
||||
------
|
||||
SessionStorageError
|
||||
Raised if an error occurs while attempting to list sessions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SessionManager(Protocol):
|
||||
"""SessionManagers are responsible for encapsulating all session lifecycle behavior
|
||||
that the Streamlit Runtime may care about.
|
||||
|
||||
A SessionManager must define the following required methods:
|
||||
- __init__
|
||||
- connect_session
|
||||
- close_session
|
||||
- get_session_info
|
||||
- list_sessions
|
||||
|
||||
SessionManager implementations may also choose to define the notions of active and
|
||||
inactive sessions. The precise definitions of active/inactive are left to the
|
||||
concrete implementation. SessionManagers that wish to differentiate between active
|
||||
and inactive sessions should have the required methods listed above operate on *all*
|
||||
sessions. Additionally, they should define the following methods for working with
|
||||
active sessions:
|
||||
- disconnect_session
|
||||
- get_active_session_info
|
||||
- is_active_session
|
||||
- list_active_sessions
|
||||
|
||||
When active session-related methods are left undefined, their default
|
||||
implementations are the naturally corresponding required methods.
|
||||
|
||||
The Runtime, unless there's a good reason to do otherwise, should generally work
|
||||
with the active-session versions of a SessionManager's methods. There isn't currently
|
||||
a need for us to be able to operate on inactive sessions stored in SessionStorage
|
||||
outside of the SessionManager itself. However, it's highly likely that we'll
|
||||
eventually have to do so, which is why the abstractions allow for this now.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: All SessionManager methods are *not* threadsafe -- they must be called
|
||||
from the runtime's eventloop thread.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
session_storage: SessionStorage,
|
||||
uploaded_file_manager: UploadedFileManager,
|
||||
script_cache: ScriptCache,
|
||||
message_enqueued_callback: Callable[[], None] | None,
|
||||
) -> None:
|
||||
"""Initialize a SessionManager with the given SessionStorage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_storage
|
||||
The SessionStorage instance backing this SessionManager.
|
||||
|
||||
uploaded_file_manager
|
||||
Used to manage files uploaded by users via the Streamlit web client.
|
||||
|
||||
script_cache
|
||||
ScriptCache instance. Caches user script bytecode.
|
||||
|
||||
message_enqueued_callback
|
||||
A callback invoked after a message is enqueued to be sent to a web client.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def connect_session(
|
||||
self,
|
||||
client: SessionClient,
|
||||
script_data: ScriptData,
|
||||
user_info: dict[str, str | bool | None],
|
||||
existing_session_id: str | None = None,
|
||||
session_id_override: str | None = None,
|
||||
) -> str:
|
||||
"""Create a new session or connect to an existing one.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client
|
||||
A concrete SessionClient implementation for communicating with
|
||||
the session's client.
|
||||
script_data
|
||||
Contains parameters related to running a script.
|
||||
user_info
|
||||
A dict that contains information about the session's user. For now,
|
||||
it only (optionally) contains the user's email address.
|
||||
|
||||
{
|
||||
"email": "example@example.com"
|
||||
}
|
||||
existing_session_id
|
||||
The ID of an existing session to reconnect to. If one is not provided, a new
|
||||
session is created. Note that whether a SessionManager supports reconnecting
|
||||
to an existing session is left up to the concrete SessionManager
|
||||
implementation. Those that do not support reconnection should simply ignore
|
||||
this argument.
|
||||
session_id_override
|
||||
The ID to assign to a new session being created with this method. Setting
|
||||
this can be useful when the service that a Streamlit Runtime is running in
|
||||
wants to tie the lifecycle of a Streamlit session to some other session-like
|
||||
object that it manages. Only one of existing_session_id and
|
||||
session_id_override should be set.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The session's unique string ID.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close_session(self, session_id: str) -> None:
|
||||
"""Close and completely delete the session with the given id.
|
||||
|
||||
This function may be called multiple times for the same session,
|
||||
which is not an error. (Subsequent calls just no-op.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_session_info(self, session_id: str) -> SessionInfo | None:
|
||||
"""Return the SessionInfo for the given id, or None if no such session
|
||||
exists.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SessionInfo or None
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def list_sessions(self) -> list[SessionInfo]:
|
||||
"""Return the SessionInfo for all sessions managed by this SessionManager.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[SessionInfo]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def num_sessions(self) -> int:
|
||||
"""Return the number of sessions tracked by this SessionManager.
|
||||
|
||||
Subclasses of SessionManager shouldn't provide their own implementation of this
|
||||
method without a *very* good reason.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
"""
|
||||
return len(self.list_sessions())
|
||||
|
||||
# NOTE: The following methods only need to be overwritten when a concrete
|
||||
# SessionManager implementation has a notion of active vs inactive sessions.
|
||||
# If left unimplemented in a subclass, the default implementations of these methods
|
||||
# call corresponding SessionManager methods in a natural way.
|
||||
|
||||
def disconnect_session(self, session_id: str) -> None:
|
||||
"""Disconnect the given session.
|
||||
|
||||
This method should be idempotent.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
"""
|
||||
self.close_session(session_id)
|
||||
|
||||
def get_active_session_info(self, session_id: str) -> ActiveSessionInfo | None:
|
||||
"""Return the ActiveSessionInfo for the given id, or None if either no such
|
||||
session exists or the session is not active.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The active session's unique ID.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ActiveSessionInfo or None
|
||||
"""
|
||||
session = self.get_session_info(session_id)
|
||||
if session is None or not session.is_active():
|
||||
return None
|
||||
return session.to_active()
|
||||
|
||||
def is_active_session(self, session_id: str) -> bool:
|
||||
"""Return True if the given session exists and is active, False otherwise.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
return self.get_active_session_info(session_id) is not None
|
||||
|
||||
def list_active_sessions(self) -> list[ActiveSessionInfo]:
|
||||
"""Return the session info for all active sessions tracked by this SessionManager.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ActiveSessionInfo]
|
||||
"""
|
||||
return [s.to_active() for s in self.list_sessions()]
|
||||
|
||||
def num_active_sessions(self) -> int:
|
||||
"""Return the number of active sessions tracked by this SessionManager.
|
||||
|
||||
Subclasses of SessionManager shouldn't provide their own implementation of this
|
||||
method without a *very* good reason.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
"""
|
||||
return len(self.list_active_sessions())
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from streamlit.runtime.state.common import WidgetArgs, WidgetCallback, WidgetKwargs
|
||||
from streamlit.runtime.state.query_params_proxy import QueryParamsProxy
|
||||
from streamlit.runtime.state.safe_session_state import SafeSessionState
|
||||
from streamlit.runtime.state.session_state import (
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
|
||||
SessionState,
|
||||
SessionStateStatProvider,
|
||||
)
|
||||
from streamlit.runtime.state.session_state_proxy import (
|
||||
SessionStateProxy,
|
||||
get_session_state,
|
||||
)
|
||||
from streamlit.runtime.state.widgets import register_widget
|
||||
|
||||
__all__ = [
|
||||
"WidgetArgs",
|
||||
"WidgetCallback",
|
||||
"WidgetKwargs",
|
||||
"QueryParamsProxy",
|
||||
"SafeSessionState",
|
||||
"SCRIPT_RUN_WITHOUT_ERRORS_KEY",
|
||||
"SessionState",
|
||||
"SessionStateStatProvider",
|
||||
"SessionStateProxy",
|
||||
"get_session_state",
|
||||
"register_widget",
|
||||
]
|
||||
@@ -0,0 +1,191 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Functions and data structures shared by session_state.py and widgets.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Final,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
|
||||
from typing_extensions import TypeAlias, TypeGuard
|
||||
|
||||
from streamlit import util
|
||||
from streamlit.errors import (
|
||||
StreamlitAPIException,
|
||||
)
|
||||
|
||||
GENERATED_ELEMENT_ID_PREFIX: Final = "$$ID"
|
||||
TESTING_KEY = "$$STREAMLIT_INTERNAL_KEY_TESTING"
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
WidgetArgs: TypeAlias = tuple[Any, ...]
|
||||
WidgetKwargs: TypeAlias = dict[str, Any]
|
||||
WidgetCallback: TypeAlias = Callable[..., None]
|
||||
|
||||
# A deserializer receives the value from whatever field is set on the
|
||||
# WidgetState proto, and returns a regular python value. A serializer
|
||||
# receives a regular python value, and returns something suitable for
|
||||
# a value field on WidgetState proto. They should be inverses.
|
||||
WidgetDeserializer: TypeAlias = Callable[[Any, str], T]
|
||||
WidgetSerializer: TypeAlias = Callable[[T], Any]
|
||||
|
||||
# The array value field names are part of the larger set of possible value
|
||||
# field names. See the explanation for said set below. The message types
|
||||
# associated with these fields are distinguished by storing data in a `data`
|
||||
# field in their messages, meaning they need special treatment in certain
|
||||
# circumstances. Hence, they need their own, dedicated, sub-type.
|
||||
ArrayValueFieldName: TypeAlias = Literal[
|
||||
"double_array_value",
|
||||
"int_array_value",
|
||||
"string_array_value",
|
||||
]
|
||||
|
||||
# A frozenset containing the allowed values of the ArrayValueFieldName type.
|
||||
# Useful for membership checking.
|
||||
_ARRAY_VALUE_FIELD_NAMES: Final = frozenset(
|
||||
cast(
|
||||
"tuple[ArrayValueFieldName, ...]",
|
||||
# NOTE: get_args is not recursive, so this only works as long as
|
||||
# ArrayValueFieldName remains flat.
|
||||
get_args(ArrayValueFieldName),
|
||||
)
|
||||
)
|
||||
|
||||
# These are the possible field names that can be set in the `value` oneof-field
|
||||
# of the WidgetState message (schema found in .proto/WidgetStates.proto).
|
||||
# We need these as a literal type to ensure correspondence with the protobuf
|
||||
# schema in certain parts of the python code.
|
||||
# TODO(harahu): It would be preferable if this type was automatically derived
|
||||
# from the protobuf schema, rather than manually maintained. Not sure how to
|
||||
# achieve that, though.
|
||||
ValueFieldName: TypeAlias = Literal[
|
||||
ArrayValueFieldName,
|
||||
"arrow_value",
|
||||
"bool_value",
|
||||
"bytes_value",
|
||||
"double_value",
|
||||
"file_uploader_state_value",
|
||||
"int_value",
|
||||
"json_value",
|
||||
"string_value",
|
||||
"trigger_value",
|
||||
"string_trigger_value",
|
||||
"chat_input_value",
|
||||
]
|
||||
|
||||
|
||||
def is_array_value_field_name(obj: object) -> TypeGuard[ArrayValueFieldName]:
|
||||
return obj in _ARRAY_VALUE_FIELD_NAMES
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WidgetMetadata(Generic[T]):
|
||||
"""Metadata associated with a single widget. Immutable."""
|
||||
|
||||
id: str
|
||||
deserializer: WidgetDeserializer[T] = field(repr=False)
|
||||
serializer: WidgetSerializer[T] = field(repr=False)
|
||||
value_type: ValueFieldName
|
||||
|
||||
# An optional user-code callback invoked when the widget's value changes.
|
||||
# Widget callbacks are called at the start of a script run, before the
|
||||
# body of the script is executed.
|
||||
callback: WidgetCallback | None = None
|
||||
callback_args: WidgetArgs | None = None
|
||||
callback_kwargs: WidgetKwargs | None = None
|
||||
|
||||
fragment_id: str | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegisterWidgetResult(Generic[T_co]):
|
||||
"""Result returned by the `register_widget` family of functions/methods.
|
||||
|
||||
Should be usable by widget code to determine what value to return, and
|
||||
whether to update the UI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : T_co
|
||||
The widget's current value, or, in cases where the true widget value
|
||||
could not be determined, an appropriate fallback value.
|
||||
|
||||
This value should be returned by the widget call.
|
||||
value_changed : bool
|
||||
True if the widget's value is different from the value most recently
|
||||
returned from the frontend.
|
||||
|
||||
Implies an update to the frontend is needed.
|
||||
"""
|
||||
|
||||
value: T_co
|
||||
value_changed: bool
|
||||
|
||||
@classmethod
|
||||
def failure(
|
||||
cls, deserializer: WidgetDeserializer[T_co]
|
||||
) -> RegisterWidgetResult[T_co]:
|
||||
"""The canonical way to construct a RegisterWidgetResult in cases
|
||||
where the true widget value could not be determined.
|
||||
"""
|
||||
return cls(value=deserializer(None, ""), value_changed=False)
|
||||
|
||||
|
||||
def user_key_from_element_id(element_id: str) -> str | None:
|
||||
"""Return the user key portion of a element id, or None if the id does not
|
||||
have a user key.
|
||||
|
||||
TODO This will incorrectly indicate no user key if the user actually provides
|
||||
"None" as a key, but we can't avoid this kind of problem while storing the
|
||||
string representation of the no-user-key sentinel as part of the element id.
|
||||
"""
|
||||
user_key: str | None = element_id.split("-", maxsplit=2)[-1]
|
||||
return None if user_key == "None" else user_key
|
||||
|
||||
|
||||
def is_element_id(key: str) -> bool:
|
||||
"""True if the given session_state key has the structure of a element ID."""
|
||||
return key.startswith(GENERATED_ELEMENT_ID_PREFIX)
|
||||
|
||||
|
||||
def is_keyed_element_id(key: str) -> bool:
|
||||
"""True if the given session_state key has the structure of a element ID
|
||||
with a user_key.
|
||||
"""
|
||||
return is_element_id(key) and not key.endswith("-None")
|
||||
|
||||
|
||||
def require_valid_user_key(key: str) -> None:
|
||||
"""Raise an Exception if the given user_key is invalid."""
|
||||
if is_element_id(key):
|
||||
raise StreamlitAPIException(
|
||||
f"Keys beginning with {GENERATED_ELEMENT_ID_PREFIX} are reserved."
|
||||
)
|
||||
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Iterator, MutableMapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Final
|
||||
from urllib import parse
|
||||
|
||||
from streamlit.errors import StreamlitAPIException
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsKeysAndGetItem
|
||||
|
||||
|
||||
EMBED_QUERY_PARAM: Final[str] = "embed"
|
||||
EMBED_OPTIONS_QUERY_PARAM: Final[str] = "embed_options"
|
||||
EMBED_QUERY_PARAMS_KEYS: Final[list[str]] = [
|
||||
EMBED_QUERY_PARAM,
|
||||
EMBED_OPTIONS_QUERY_PARAM,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryParams(MutableMapping[str, str]):
|
||||
"""A lightweight wrapper of a dict that sends forwardMsgs when state changes.
|
||||
It stores str keys with str and List[str] values.
|
||||
"""
|
||||
|
||||
_query_params: dict[str, list[str] | str] = field(default_factory=dict)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
self._ensure_single_query_api_used()
|
||||
|
||||
return iter(
|
||||
key
|
||||
for key in self._query_params.keys()
|
||||
if key not in EMBED_QUERY_PARAMS_KEYS
|
||||
)
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
"""Retrieves a value for a given key in query parameters.
|
||||
Returns the last item in a list or an empty string if empty.
|
||||
If the key is not present, raise KeyError.
|
||||
"""
|
||||
self._ensure_single_query_api_used()
|
||||
try:
|
||||
if key in EMBED_QUERY_PARAMS_KEYS:
|
||||
raise KeyError(missing_key_error_message(key))
|
||||
value = self._query_params[key]
|
||||
if isinstance(value, list):
|
||||
if len(value) == 0:
|
||||
return ""
|
||||
else:
|
||||
# Return the last value to mimic Tornado's behavior
|
||||
# https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.get_query_argument
|
||||
return value[-1]
|
||||
return value
|
||||
except KeyError:
|
||||
raise KeyError(missing_key_error_message(key))
|
||||
|
||||
def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
|
||||
self._ensure_single_query_api_used()
|
||||
self.__set_item_internal(key, value)
|
||||
self._send_query_param_msg()
|
||||
|
||||
def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None:
|
||||
if isinstance(value, dict):
|
||||
raise StreamlitAPIException(
|
||||
f"You cannot set a query params key `{key}` to a dictionary."
|
||||
)
|
||||
|
||||
if key in EMBED_QUERY_PARAMS_KEYS:
|
||||
raise StreamlitAPIException(
|
||||
"Query param embed and embed_options (case-insensitive) cannot be set programmatically."
|
||||
)
|
||||
# Type checking users should handle the string serialization themselves
|
||||
# We will accept any type for the list and serialize to str just in case
|
||||
if isinstance(value, Iterable) and not isinstance(value, str):
|
||||
self._query_params[key] = [str(item) for item in value]
|
||||
else:
|
||||
self._query_params[key] = str(value)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
self._ensure_single_query_api_used()
|
||||
try:
|
||||
if key in EMBED_QUERY_PARAMS_KEYS:
|
||||
raise KeyError(missing_key_error_message(key))
|
||||
del self._query_params[key]
|
||||
self._send_query_param_msg()
|
||||
except KeyError:
|
||||
raise KeyError(missing_key_error_message(key))
|
||||
|
||||
def update(
|
||||
self,
|
||||
other: Iterable[tuple[str, str | Iterable[str]]]
|
||||
| SupportsKeysAndGetItem[str, str | Iterable[str]] = (),
|
||||
/,
|
||||
**kwds: str,
|
||||
):
|
||||
# This overrides the `update` provided by MutableMapping
|
||||
# to ensure only one one ForwardMsg is sent.
|
||||
self._ensure_single_query_api_used()
|
||||
if hasattr(other, "keys") and hasattr(other, "__getitem__"):
|
||||
for key in other.keys():
|
||||
self.__set_item_internal(key, other[key])
|
||||
else:
|
||||
for key, value in other:
|
||||
self.__set_item_internal(key, value)
|
||||
for key, value in kwds.items():
|
||||
self.__set_item_internal(key, value)
|
||||
self._send_query_param_msg()
|
||||
|
||||
def get_all(self, key: str) -> list[str]:
|
||||
self._ensure_single_query_api_used()
|
||||
if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS:
|
||||
return []
|
||||
value = self._query_params[key]
|
||||
return value if isinstance(value, list) else [value]
|
||||
|
||||
def __len__(self) -> int:
|
||||
self._ensure_single_query_api_used()
|
||||
return len(
|
||||
{key for key in self._query_params if key not in EMBED_QUERY_PARAMS_KEYS}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
self._ensure_single_query_api_used()
|
||||
return str(self._query_params)
|
||||
|
||||
def _send_query_param_msg(self) -> None:
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return
|
||||
self._ensure_single_query_api_used()
|
||||
|
||||
msg = ForwardMsg()
|
||||
msg.page_info_changed.query_string = parse.urlencode(
|
||||
self._query_params, doseq=True
|
||||
)
|
||||
ctx.query_string = msg.page_info_changed.query_string
|
||||
ctx.enqueue(msg)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._ensure_single_query_api_used()
|
||||
self.clear_with_no_forward_msg(preserve_embed=True)
|
||||
self._send_query_param_msg()
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
self._ensure_single_query_api_used()
|
||||
# return the last query param if multiple values are set
|
||||
return {
|
||||
key: self[key]
|
||||
for key in self._query_params
|
||||
if key not in EMBED_QUERY_PARAMS_KEYS
|
||||
}
|
||||
|
||||
def from_dict(
|
||||
self,
|
||||
_dict: Iterable[tuple[str, str | Iterable[str]]]
|
||||
| SupportsKeysAndGetItem[str, str | Iterable[str]],
|
||||
):
|
||||
self._ensure_single_query_api_used()
|
||||
old_value = self._query_params.copy()
|
||||
self.clear_with_no_forward_msg(preserve_embed=True)
|
||||
try:
|
||||
self.update(_dict)
|
||||
except StreamlitAPIException:
|
||||
# restore the original from before we made any changes.
|
||||
self._query_params = old_value
|
||||
raise
|
||||
|
||||
def set_with_no_forward_msg(self, key: str, val: list[str] | str) -> None:
|
||||
self._query_params[key] = val
|
||||
|
||||
def clear_with_no_forward_msg(self, preserve_embed: bool = False) -> None:
|
||||
self._query_params = {
|
||||
key: value
|
||||
for key, value in self._query_params.items()
|
||||
if key in EMBED_QUERY_PARAMS_KEYS and preserve_embed
|
||||
}
|
||||
|
||||
def _ensure_single_query_api_used(self):
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return
|
||||
ctx.mark_production_query_params_used()
|
||||
|
||||
|
||||
def missing_key_error_message(key: str) -> str:
|
||||
return f'st.query_params has no key "{key}".'
|
||||
@@ -0,0 +1,218 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Iterator, MutableMapping
|
||||
from typing import TYPE_CHECKING, overload
|
||||
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.state.session_state_proxy import get_session_state
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsKeysAndGetItem
|
||||
|
||||
|
||||
class QueryParamsProxy(MutableMapping[str, str]):
|
||||
"""
|
||||
A stateless singleton that proxies ``st.query_params`` interactions
|
||||
to the current script thread's QueryParams instance.
|
||||
"""
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
with get_session_state().query_params() as qp:
|
||||
return iter(qp)
|
||||
|
||||
def __len__(self) -> int:
|
||||
with get_session_state().query_params() as qp:
|
||||
return len(qp)
|
||||
|
||||
def __str__(self) -> str:
|
||||
with get_session_state().query_params() as qp:
|
||||
return str(qp)
|
||||
|
||||
@gather_metrics("query_params.get_item")
|
||||
def __getitem__(self, key: str) -> str:
|
||||
with get_session_state().query_params() as qp:
|
||||
try:
|
||||
return qp[key]
|
||||
except KeyError:
|
||||
raise KeyError(self.missing_key_error_message(key))
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
with get_session_state().query_params() as qp:
|
||||
del qp[key]
|
||||
|
||||
@gather_metrics("query_params.set_item")
|
||||
def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
|
||||
with get_session_state().query_params() as qp:
|
||||
qp[key] = value
|
||||
|
||||
@gather_metrics("query_params.get_attr")
|
||||
def __getattr__(self, key: str) -> str:
|
||||
with get_session_state().query_params() as qp:
|
||||
try:
|
||||
return qp[key]
|
||||
except KeyError:
|
||||
raise AttributeError(self.missing_attr_error_message(key))
|
||||
|
||||
def __delattr__(self, key: str) -> None:
|
||||
with get_session_state().query_params() as qp:
|
||||
try:
|
||||
del qp[key]
|
||||
except KeyError:
|
||||
raise AttributeError(self.missing_key_error_message(key))
|
||||
|
||||
@overload
|
||||
def update(
|
||||
self, mapping: SupportsKeysAndGetItem[str, str | Iterable[str]], /, **kwds: str
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def update(
|
||||
self, keys_and_values: Iterable[tuple[str, str | Iterable[str]]], /, **kwds: str
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def update(self, **kwds: str | Iterable[str]) -> None: ...
|
||||
|
||||
def update(self, other=(), /, **kwds):
|
||||
"""
|
||||
Update one or more values in query_params at once from a dictionary or
|
||||
dictionary-like object.
|
||||
|
||||
See `Mapping.update()` from Python's `collections` library.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
other: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]]
|
||||
A dictionary or mapping of strings to strings.
|
||||
**kwds: str
|
||||
Additional key/value pairs to update passed as keyword arguments.
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
qp.update(other, **kwds)
|
||||
|
||||
@gather_metrics("query_params.set_attr")
|
||||
def __setattr__(self, key: str, value: str | Iterable[str]) -> None:
|
||||
with get_session_state().query_params() as qp:
|
||||
qp[key] = value
|
||||
|
||||
@gather_metrics("query_params.get_all")
|
||||
def get_all(self, key: str) -> list[str]:
|
||||
"""
|
||||
Get a list of all query parameter values associated to a given key.
|
||||
|
||||
When a key is repeated as a query parameter within the URL, this method
|
||||
allows all values to be obtained. In contrast, dict-like methods only
|
||||
retrieve the last value when a key is repeated in the URL.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key: str
|
||||
The label of the query parameter in the URL.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
A list of values associated to the given key. May return zero, one,
|
||||
or multiple values.
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
return qp.get_all(key)
|
||||
|
||||
@gather_metrics("query_params.clear")
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all query parameters from the URL of the app.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
qp.clear()
|
||||
|
||||
@gather_metrics("query_params.to_dict")
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
"""
|
||||
Get all query parameters as a dictionary.
|
||||
|
||||
This method primarily exists for internal use and is not needed for
|
||||
most cases. ``st.query_params`` returns an object that inherits from
|
||||
``dict`` by default.
|
||||
|
||||
When a key is repeated as a query parameter within the URL, this method
|
||||
will return only the last value of each unique key.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str,str]
|
||||
A dictionary of the current query paramters in the app's URL.
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
return qp.to_dict()
|
||||
|
||||
@overload
|
||||
def from_dict(
|
||||
self, keys_and_values: Iterable[tuple[str, str | Iterable[str]]]
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def from_dict(
|
||||
self, mapping: SupportsKeysAndGetItem[str, str | Iterable[str]]
|
||||
) -> None: ...
|
||||
|
||||
@gather_metrics("query_params.from_dict")
|
||||
def from_dict(self, params):
|
||||
"""
|
||||
Set all of the query parameters from a dictionary or dictionary-like object.
|
||||
|
||||
This method primarily exists for advanced users who want to control
|
||||
multiple query parameters in a single update. To set individual query
|
||||
parameters, use key or attribute notation instead.
|
||||
|
||||
This method inherits limitations from ``st.query_params`` and can't be
|
||||
used to set embedding options as described in `Embed your app \
|
||||
<https://docs.streamlit.io/deploy/streamlit-community-cloud/share-your-app/embed-your-app#embed-options>`_.
|
||||
|
||||
To handle repeated keys, the value in a key-value pair should be a list.
|
||||
|
||||
.. note::
|
||||
``.from_dict()`` is not a direct inverse of ``.to_dict()`` if
|
||||
you are working with repeated keys. A true inverse operation is
|
||||
``{key: st.query_params.get_all(key) for key in st.query_params}``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params: dict
|
||||
A dictionary used to replace the current query parameters.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> st.query_params.from_dict({"foo": "bar", "baz": [1, "two"]})
|
||||
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
return qp.from_dict(params)
|
||||
|
||||
@staticmethod
|
||||
def missing_key_error_message(key: str) -> str:
|
||||
return f'st.query_params has no key "{key}".'
|
||||
|
||||
@staticmethod
|
||||
def missing_attr_error_message(key: str) -> str:
|
||||
return f'st.query_params has no attribute "{key}".'
|
||||
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
|
||||
from streamlit.runtime.state.common import RegisterWidgetResult, T, WidgetMetadata
|
||||
from streamlit.runtime.state.query_params import QueryParams
|
||||
from streamlit.runtime.state.session_state import SessionState
|
||||
|
||||
|
||||
class SafeSessionState:
|
||||
"""Thread-safe wrapper around SessionState.
|
||||
|
||||
When AppSession gets a re-run request, it can interrupt its existing
|
||||
ScriptRunner and spin up a new ScriptRunner to handle the request.
|
||||
When this happens, the existing ScriptRunner will continue executing
|
||||
its script until it reaches a yield point - but during this time, it
|
||||
must not mutate its SessionState.
|
||||
"""
|
||||
|
||||
_state: SessionState
|
||||
_lock: threading.RLock
|
||||
_yield_callback: Callable[[], None]
|
||||
|
||||
def __init__(self, state: SessionState, yield_callback: Callable[[], None]):
|
||||
# Fields must be set using the object's setattr method to avoid
|
||||
# infinite recursion from trying to look up the fields we're setting.
|
||||
object.__setattr__(self, "_state", state)
|
||||
# TODO: we'd prefer this be a threading.Lock instead of RLock -
|
||||
# but `call_callbacks` first needs to be rewritten.
|
||||
object.__setattr__(self, "_lock", threading.RLock())
|
||||
object.__setattr__(self, "_yield_callback", yield_callback)
|
||||
|
||||
def register_widget(
|
||||
self, metadata: WidgetMetadata[T], user_key: str | None
|
||||
) -> RegisterWidgetResult[T]:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
return self._state.register_widget(metadata, user_key)
|
||||
|
||||
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
# TODO: rewrite this to copy the callbacks list into a local
|
||||
# variable so that we don't need to hold our lock for the
|
||||
# duration. (This will also allow us to downgrade our RLock
|
||||
# to a Lock.)
|
||||
self._state.on_script_will_rerun(latest_widget_states)
|
||||
|
||||
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
|
||||
with self._lock:
|
||||
self._state.on_script_finished(widget_ids_this_run)
|
||||
|
||||
def maybe_check_serializable(self) -> None:
|
||||
with self._lock:
|
||||
self._state.maybe_check_serializable()
|
||||
|
||||
def get_widget_states(self) -> list[WidgetStateProto]:
|
||||
"""Return a list of serialized widget values for each widget with a value."""
|
||||
with self._lock:
|
||||
return self._state.get_widget_states()
|
||||
|
||||
def is_new_state_value(self, user_key: str) -> bool:
|
||||
with self._lock:
|
||||
return self._state.is_new_state_value(user_key)
|
||||
|
||||
@property
|
||||
def filtered_state(self) -> dict[str, Any]:
|
||||
"""The combined session and widget state, excluding keyless widgets."""
|
||||
with self._lock:
|
||||
return self._state.filtered_state
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
return self._state[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
self._state[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
del self._state[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
return key in self._state
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(f"{key} not found in session_state.")
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key: str) -> None:
|
||||
try:
|
||||
del self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(f"{key} not found in session_state.")
|
||||
|
||||
def __repr__(self):
|
||||
"""Presents itself as a simple dict of the underlying SessionState instance."""
|
||||
kv = ((k, self._state[k]) for k in self._state._keys())
|
||||
s = ", ".join(f"{k}: {v!r}" for k, v in kv)
|
||||
return f"{{{s}}}"
|
||||
|
||||
@contextmanager
|
||||
def query_params(self) -> Iterator[QueryParams]:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
yield self._state.query_params
|
||||
@@ -0,0 +1,773 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from collections.abc import Iterator, KeysView, MutableMapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Final,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import streamlit as st
|
||||
from streamlit import config, util
|
||||
from streamlit.errors import StreamlitAPIException, UnserializableSessionStateError
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
|
||||
from streamlit.runtime.state.common import (
|
||||
RegisterWidgetResult,
|
||||
T,
|
||||
ValueFieldName,
|
||||
WidgetMetadata,
|
||||
is_array_value_field_name,
|
||||
is_element_id,
|
||||
is_keyed_element_id,
|
||||
)
|
||||
from streamlit.runtime.state.query_params import QueryParams
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.session_manager import SessionManager
|
||||
|
||||
|
||||
STREAMLIT_INTERNAL_KEY_PREFIX: Final = "$$STREAMLIT_INTERNAL_KEY"
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY: Final = (
|
||||
f"{STREAMLIT_INTERNAL_KEY_PREFIX}_SCRIPT_RUN_WITHOUT_ERRORS"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Serialized:
|
||||
"""A widget value that's serialized to a protobuf. Immutable."""
|
||||
|
||||
value: WidgetStateProto
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Value:
|
||||
"""A widget value that's not serialized. Immutable."""
|
||||
|
||||
value: Any
|
||||
|
||||
|
||||
WState: TypeAlias = Union[Value, Serialized]
|
||||
|
||||
|
||||
@dataclass
|
||||
class WStates(MutableMapping[str, Any]):
|
||||
"""A mapping of widget IDs to values. Widget values can be stored in
|
||||
serialized or deserialized form, but when values are retrieved from the
|
||||
mapping, they'll always be deserialized.
|
||||
"""
|
||||
|
||||
states: dict[str, WState] = field(default_factory=dict)
|
||||
widget_metadata: dict[str, WidgetMetadata[Any]] = field(default_factory=dict)
|
||||
|
||||
def __repr__(self):
|
||||
return util.repr_(self)
|
||||
|
||||
def __getitem__(self, k: str) -> Any:
|
||||
"""Return the value of the widget with the given key.
|
||||
If the widget's value is currently stored in serialized form, it
|
||||
will be deserialized first.
|
||||
"""
|
||||
wstate = self.states.get(k)
|
||||
if wstate is None:
|
||||
raise KeyError(k)
|
||||
|
||||
if isinstance(wstate, Value):
|
||||
# The widget's value is already deserialized - return it directly.
|
||||
return wstate.value
|
||||
|
||||
# The widget's value is serialized. We deserialize it, and return
|
||||
# the deserialized value.
|
||||
|
||||
metadata = self.widget_metadata.get(k)
|
||||
if metadata is None:
|
||||
# No deserializer, which should only happen if state is
|
||||
# gotten from a reconnecting browser and the script is
|
||||
# trying to access it. Pretend it doesn't exist.
|
||||
raise KeyError(k)
|
||||
value_field_name = cast(
|
||||
"ValueFieldName",
|
||||
wstate.value.WhichOneof("value"),
|
||||
)
|
||||
value = (
|
||||
wstate.value.__getattribute__(value_field_name)
|
||||
if value_field_name # Field name is None if the widget value was cleared
|
||||
else None
|
||||
)
|
||||
|
||||
if is_array_value_field_name(value_field_name):
|
||||
# Array types are messages with data in a `data` field
|
||||
value = value.data
|
||||
elif value_field_name == "json_value":
|
||||
value = json.loads(value)
|
||||
|
||||
deserialized = metadata.deserializer(value, metadata.id)
|
||||
|
||||
# Update metadata to reflect information from WidgetState proto
|
||||
self.set_widget_metadata(
|
||||
replace(
|
||||
metadata,
|
||||
value_type=value_field_name,
|
||||
)
|
||||
)
|
||||
|
||||
self.states[k] = Value(deserialized)
|
||||
return deserialized
|
||||
|
||||
def __setitem__(self, k: str, v: WState) -> None:
|
||||
self.states[k] = v
|
||||
|
||||
def __delitem__(self, k: str) -> None:
|
||||
del self.states[k]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.states)
|
||||
|
||||
def __iter__(self):
|
||||
# For this and many other methods, we can't simply delegate to the
|
||||
# states field, because we need to invoke `__getitem__` for any
|
||||
# values, to handle deserialization and unwrapping of values.
|
||||
yield from self.states
|
||||
|
||||
def keys(self) -> KeysView[str]:
|
||||
return KeysView(self.states)
|
||||
|
||||
def items(self) -> set[tuple[str, Any]]: # type: ignore[override]
|
||||
return {(k, self[k]) for k in self}
|
||||
|
||||
def values(self) -> set[Any]: # type: ignore[override]
|
||||
return {self[wid] for wid in self}
|
||||
|
||||
def update(self, other: WStates) -> None: # type: ignore[override]
|
||||
"""Copy all widget values and metadata from 'other' into this mapping,
|
||||
overwriting any data in this mapping that's also present in 'other'.
|
||||
"""
|
||||
self.states.update(other.states)
|
||||
self.widget_metadata.update(other.widget_metadata)
|
||||
|
||||
def set_widget_from_proto(self, widget_state: WidgetStateProto) -> None:
|
||||
"""Set a widget's serialized value, overwriting any existing value it has."""
|
||||
self[widget_state.id] = Serialized(widget_state)
|
||||
|
||||
def set_from_value(self, k: str, v: Any) -> None:
|
||||
"""Set a widget's deserialized value, overwriting any existing value it has."""
|
||||
self[k] = Value(v)
|
||||
|
||||
def set_widget_metadata(self, widget_meta: WidgetMetadata[Any]) -> None:
|
||||
"""Set a widget's metadata, overwriting any existing metadata it has."""
|
||||
self.widget_metadata[widget_meta.id] = widget_meta
|
||||
|
||||
def remove_stale_widgets(
|
||||
self,
|
||||
active_widget_ids: set[str],
|
||||
fragment_ids_this_run: list[str] | None,
|
||||
) -> None:
|
||||
"""Remove widget state for stale widgets."""
|
||||
self.states = {
|
||||
k: v
|
||||
for k, v in self.states.items()
|
||||
if not _is_stale_widget(
|
||||
self.widget_metadata.get(k),
|
||||
active_widget_ids,
|
||||
fragment_ids_this_run,
|
||||
)
|
||||
}
|
||||
|
||||
def get_serialized(self, k: str) -> WidgetStateProto | None:
|
||||
"""Get the serialized value of the widget with the given id.
|
||||
|
||||
If the widget doesn't exist, return None. If the widget exists but
|
||||
is not in serialized form, it will be serialized first.
|
||||
"""
|
||||
|
||||
item = self.states.get(k)
|
||||
if item is None:
|
||||
# No such widget: return None.
|
||||
return None
|
||||
|
||||
if isinstance(item, Serialized):
|
||||
# Widget value is serialized: return it directly.
|
||||
return item.value
|
||||
|
||||
# Widget value is not serialized: serialize it first!
|
||||
metadata = self.widget_metadata.get(k)
|
||||
if metadata is None:
|
||||
# We're missing the widget's metadata. (Can this happen?)
|
||||
return None
|
||||
|
||||
widget = WidgetStateProto()
|
||||
widget.id = k
|
||||
|
||||
field = metadata.value_type
|
||||
serialized = metadata.serializer(item.value)
|
||||
|
||||
if is_array_value_field_name(field):
|
||||
arr = getattr(widget, field)
|
||||
arr.data.extend(serialized)
|
||||
elif field == "json_value":
|
||||
setattr(widget, field, json.dumps(serialized))
|
||||
elif field == "file_uploader_state_value":
|
||||
widget.file_uploader_state_value.CopyFrom(serialized)
|
||||
elif field == "string_trigger_value":
|
||||
widget.string_trigger_value.CopyFrom(serialized)
|
||||
elif field == "chat_input_value":
|
||||
widget.chat_input_value.CopyFrom(serialized)
|
||||
elif field is not None and serialized is not None:
|
||||
# If the field is None, the widget value was cleared
|
||||
# by the user and therefore is None. But we cannot
|
||||
# set it to None here, since the proto properties are
|
||||
# not nullable. So we just don't set it.
|
||||
setattr(widget, field, serialized)
|
||||
|
||||
return widget
|
||||
|
||||
def as_widget_states(self) -> list[WidgetStateProto]:
|
||||
"""Return a list of serialized widget values for each widget with a value."""
|
||||
states = [
|
||||
self.get_serialized(widget_id)
|
||||
for widget_id in self.states.keys()
|
||||
if self.get_serialized(widget_id)
|
||||
]
|
||||
states = cast("list[WidgetStateProto]", states)
|
||||
return states
|
||||
|
||||
def call_callback(self, widget_id: str) -> None:
|
||||
"""Call the given widget's callback and return the callback's
|
||||
return value. If the widget has no callback, return None.
|
||||
|
||||
If the widget doesn't exist, raise an Exception.
|
||||
"""
|
||||
metadata = self.widget_metadata.get(widget_id)
|
||||
assert metadata is not None
|
||||
callback = metadata.callback
|
||||
if callback is None:
|
||||
return
|
||||
|
||||
args = metadata.callback_args or ()
|
||||
kwargs = metadata.callback_kwargs or {}
|
||||
callback(*args, **kwargs)
|
||||
|
||||
|
||||
def _missing_key_error_message(key: str) -> str:
|
||||
return (
|
||||
f'st.session_state has no key "{key}". Did you forget to initialize it? '
|
||||
f"More info: https://docs.streamlit.io/develop/concepts/architecture/session-state#initialization"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeyIdMapper:
|
||||
"""A mapping of user-provided keys to element IDs.
|
||||
It also maps element IDs to user-provided keys so that this reverse mapping
|
||||
does not have to be computed ad-hoc.
|
||||
All built-in dict-operations such as setting and deleting expect the key as the
|
||||
argument, not the element ID.
|
||||
"""
|
||||
|
||||
_key_id_mapping: dict[str, str] = field(default_factory=dict)
|
||||
_id_key_mapping: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self._key_id_mapping
|
||||
|
||||
def __setitem__(self, key: str, widget_id: Any) -> None:
|
||||
self._key_id_mapping[key] = widget_id
|
||||
self._id_key_mapping[widget_id] = key
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
self.delete(key)
|
||||
|
||||
@property
|
||||
def id_key_mapping(self) -> dict[str, str]:
|
||||
return self._id_key_mapping
|
||||
|
||||
def set_key_id_mapping(self, key_id_mapping: dict[str, str]) -> None:
|
||||
self._key_id_mapping = key_id_mapping
|
||||
self._id_key_mapping = {v: k for k, v in key_id_mapping.items()}
|
||||
|
||||
def get_id_from_key(self, key: str, default: Any = None) -> str:
|
||||
return self._key_id_mapping.get(key, default)
|
||||
|
||||
def get_key_from_id(self, widget_id: str) -> str:
|
||||
return self._id_key_mapping[widget_id]
|
||||
|
||||
def update(self, other: KeyIdMapper) -> None:
|
||||
self._key_id_mapping.update(other._key_id_mapping)
|
||||
self._id_key_mapping.update(other._id_key_mapping)
|
||||
|
||||
def clear(self):
|
||||
self._key_id_mapping.clear()
|
||||
self._id_key_mapping.clear()
|
||||
|
||||
def delete(self, key: str):
|
||||
widget_id = self._key_id_mapping[key]
|
||||
del self._key_id_mapping[key]
|
||||
del self._id_key_mapping[widget_id]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
"""SessionState allows users to store values that persist between app
|
||||
reruns.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> if "num_script_runs" not in st.session_state:
|
||||
... st.session_state.num_script_runs = 0
|
||||
>>> st.session_state.num_script_runs += 1
|
||||
>>> st.write(st.session_state.num_script_runs) # writes 1
|
||||
|
||||
The next time your script runs, the value of
|
||||
st.session_state.num_script_runs will be preserved.
|
||||
>>> st.session_state.num_script_runs += 1
|
||||
>>> st.write(st.session_state.num_script_runs) # writes 2
|
||||
"""
|
||||
|
||||
# All the values from previous script runs, squished together to save memory
|
||||
_old_state: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Values set in session state during the current script run, possibly for
|
||||
# setting a widget's value. Keyed by a user provided string.
|
||||
_new_session_state: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Widget values from the frontend, usually one changing prompted the script rerun
|
||||
_new_widget_state: WStates = field(default_factory=WStates)
|
||||
|
||||
# Keys used for widgets will be eagerly converted to the matching element id
|
||||
_key_id_mapper: KeyIdMapper = field(default_factory=KeyIdMapper)
|
||||
|
||||
# query params are stored in session state because query params will be tied with
|
||||
# widget state at one point.
|
||||
query_params: QueryParams = field(default_factory=QueryParams)
|
||||
|
||||
def __repr__(self):
|
||||
return util.repr_(self)
|
||||
|
||||
# is it possible for a value to get through this without being deserialized?
|
||||
def _compact_state(self) -> None:
|
||||
"""Copy all current session_state and widget_state values into our
|
||||
_old_state dict, and then clear our current session_state and
|
||||
widget_state.
|
||||
"""
|
||||
for key_or_wid in self:
|
||||
try:
|
||||
self._old_state[key_or_wid] = self[key_or_wid]
|
||||
except KeyError:
|
||||
# handle key errors from widget state not having metadata gracefully
|
||||
# https://github.com/streamlit/streamlit/issues/7206
|
||||
pass
|
||||
self._new_session_state.clear()
|
||||
self._new_widget_state.clear()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Reset self completely, clearing all current and old values."""
|
||||
self._old_state.clear()
|
||||
self._new_session_state.clear()
|
||||
self._new_widget_state.clear()
|
||||
self._key_id_mapper.clear()
|
||||
|
||||
@property
|
||||
def filtered_state(self) -> dict[str, Any]:
|
||||
"""The combined session and widget state, excluding keyless widgets."""
|
||||
|
||||
wid_key_map = self._key_id_mapper.id_key_mapping
|
||||
|
||||
state: dict[str, Any] = {}
|
||||
|
||||
# We can't write `for k, v in self.items()` here because doing so will
|
||||
# run into a `KeyError` if widget metadata has been cleared (which
|
||||
# happens when the streamlit server restarted or the cache was cleared),
|
||||
# then we receive a widget's state from a browser.
|
||||
for k in self._keys():
|
||||
if not is_element_id(k) and not _is_internal_key(k):
|
||||
state[k] = self[k]
|
||||
elif is_keyed_element_id(k):
|
||||
try:
|
||||
key = wid_key_map[k]
|
||||
state[key] = self[k]
|
||||
except KeyError:
|
||||
# Widget id no longer maps to a key, it is a not yet
|
||||
# cleared value in old state for a reset widget
|
||||
pass
|
||||
|
||||
return state
|
||||
|
||||
def _keys(self) -> set[str]:
|
||||
"""All keys active in Session State, with widget keys converted
|
||||
to widget ids when one is known. (This includes autogenerated keys
|
||||
for widgets that don't have user_keys defined, and which aren't
|
||||
exposed to user code).
|
||||
"""
|
||||
old_keys = {self._get_widget_id(k) for k in self._old_state.keys()}
|
||||
new_widget_keys = set(self._new_widget_state.keys())
|
||||
new_session_state_keys = {
|
||||
self._get_widget_id(k) for k in self._new_session_state.keys()
|
||||
}
|
||||
return old_keys | new_widget_keys | new_session_state_keys
|
||||
|
||||
def is_new_state_value(self, user_key: str) -> bool:
|
||||
"""True if a value with the given key is in the current session state."""
|
||||
return user_key in self._new_session_state
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
"""Return an iterator over the keys of the SessionState.
|
||||
This is a shortcut for `iter(self.keys())`.
|
||||
"""
|
||||
return iter(self._keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of items in SessionState."""
|
||||
return len(self._keys())
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
wid_key_map = self._key_id_mapper.id_key_mapping
|
||||
widget_id = self._get_widget_id(key)
|
||||
|
||||
if widget_id in wid_key_map and widget_id == key:
|
||||
# the "key" is a raw widget id, so get its associated user key for lookup
|
||||
key = wid_key_map[widget_id]
|
||||
try:
|
||||
return self._getitem(widget_id, key)
|
||||
except KeyError:
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
def _getitem(self, widget_id: str | None, user_key: str | None) -> Any:
|
||||
"""Get the value of an entry in Session State, using either the
|
||||
user-provided key or a widget id as appropriate for the internal dict
|
||||
being accessed.
|
||||
|
||||
At least one of the arguments must have a value.
|
||||
"""
|
||||
assert user_key is not None or widget_id is not None
|
||||
|
||||
if user_key is not None:
|
||||
try:
|
||||
return self._new_session_state[user_key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if widget_id is not None:
|
||||
try:
|
||||
return self._new_widget_state[widget_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Typically, there won't be both a widget id and an associated state key in
|
||||
# old state at the same time, so the order we check is arbitrary.
|
||||
# The exception is if session state is set and then a later run has
|
||||
# a widget created, so the widget id entry should be newer.
|
||||
# The opposite case shouldn't happen, because setting the value of a widget
|
||||
# through session state will result in the next widget state reflecting that
|
||||
# value.
|
||||
if widget_id is not None:
|
||||
try:
|
||||
return self._old_state[widget_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if user_key is not None:
|
||||
try:
|
||||
return self._old_state[user_key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# We'll never get here
|
||||
raise KeyError
|
||||
|
||||
def __setitem__(self, user_key: str, value: Any) -> None:
|
||||
"""Set the value of the session_state entry with the given user_key.
|
||||
|
||||
If the key corresponds to a widget or form that's been instantiated
|
||||
during the current script run, raise a StreamlitAPIException instead.
|
||||
"""
|
||||
ctx = get_script_run_ctx()
|
||||
|
||||
if ctx is not None:
|
||||
widget_id = self._key_id_mapper.get_id_from_key(user_key, None)
|
||||
widget_ids = ctx.widget_ids_this_run
|
||||
form_ids = ctx.form_ids_this_run
|
||||
|
||||
if widget_id in widget_ids or user_key in form_ids:
|
||||
raise StreamlitAPIException(
|
||||
f"`st.session_state.{user_key}` cannot be modified after the widget"
|
||||
f" with key `{user_key}` is instantiated."
|
||||
)
|
||||
|
||||
self._new_session_state[user_key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
widget_id = self._get_widget_id(key)
|
||||
|
||||
if not (key in self or widget_id in self):
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
if key in self._new_session_state:
|
||||
del self._new_session_state[key]
|
||||
|
||||
if key in self._old_state:
|
||||
del self._old_state[key]
|
||||
|
||||
if key in self._key_id_mapper:
|
||||
self._key_id_mapper.delete(key)
|
||||
|
||||
if widget_id in self._new_widget_state:
|
||||
del self._new_widget_state[widget_id]
|
||||
|
||||
if widget_id in self._old_state:
|
||||
del self._old_state[widget_id]
|
||||
|
||||
def set_widgets_from_proto(self, widget_states: WidgetStatesProto) -> None:
|
||||
"""Set the value of all widgets represented in the given WidgetStatesProto."""
|
||||
for state in widget_states.widgets:
|
||||
self._new_widget_state.set_widget_from_proto(state)
|
||||
|
||||
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
|
||||
"""Called by ScriptRunner before its script re-runs.
|
||||
|
||||
Update widget data and call callbacks on widgets whose value changed
|
||||
between the previous and current script runs.
|
||||
"""
|
||||
# Clear any triggers that weren't reset because the script was disconnected
|
||||
self._reset_triggers()
|
||||
self._compact_state()
|
||||
self.set_widgets_from_proto(latest_widget_states)
|
||||
self._call_callbacks()
|
||||
|
||||
def _call_callbacks(self) -> None:
|
||||
"""Call any callback associated with each widget whose value
|
||||
changed between the previous and current script runs.
|
||||
"""
|
||||
from streamlit.runtime.scriptrunner import RerunException
|
||||
|
||||
changed_widget_ids = [
|
||||
wid for wid in self._new_widget_state if self._widget_changed(wid)
|
||||
]
|
||||
for wid in changed_widget_ids:
|
||||
try:
|
||||
self._new_widget_state.call_callback(wid)
|
||||
except RerunException:
|
||||
st.warning("Calling st.rerun() within a callback is a no-op.")
|
||||
|
||||
def _widget_changed(self, widget_id: str) -> bool:
|
||||
"""True if the given widget's value changed between the previous
|
||||
script run and the current script run.
|
||||
"""
|
||||
new_value = self._new_widget_state.get(widget_id)
|
||||
old_value = self._old_state.get(widget_id)
|
||||
changed: bool = new_value != old_value
|
||||
return changed
|
||||
|
||||
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
|
||||
"""Called by ScriptRunner after its script finishes running.
|
||||
Updates widgets to prepare for the next script run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
widget_ids_this_run: set[str]
|
||||
The IDs of the widgets that were accessed during the script
|
||||
run. Any widget state whose ID does *not* appear in this set
|
||||
is considered "stale" and will be removed.
|
||||
"""
|
||||
self._reset_triggers()
|
||||
self._remove_stale_widgets(widget_ids_this_run)
|
||||
|
||||
def _reset_triggers(self) -> None:
|
||||
"""Set all trigger values in our state dictionary to False."""
|
||||
for state_id in self._new_widget_state:
|
||||
metadata = self._new_widget_state.widget_metadata.get(state_id)
|
||||
if metadata is not None:
|
||||
if metadata.value_type == "trigger_value":
|
||||
self._new_widget_state[state_id] = Value(False)
|
||||
elif metadata.value_type == "string_trigger_value":
|
||||
self._new_widget_state[state_id] = Value(None)
|
||||
elif metadata.value_type == "chat_input_value":
|
||||
self._new_widget_state[state_id] = Value(None)
|
||||
|
||||
for state_id in self._old_state:
|
||||
metadata = self._new_widget_state.widget_metadata.get(state_id)
|
||||
if metadata is not None:
|
||||
if metadata.value_type == "trigger_value":
|
||||
self._old_state[state_id] = False
|
||||
elif metadata.value_type == "string_trigger_value":
|
||||
self._old_state[state_id] = None
|
||||
elif metadata.value_type == "chat_input_value":
|
||||
self._old_state[state_id] = None
|
||||
|
||||
def _remove_stale_widgets(self, active_widget_ids: set[str]) -> None:
|
||||
"""Remove widget state for widgets whose ids aren't in `active_widget_ids`."""
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return
|
||||
|
||||
self._new_widget_state.remove_stale_widgets(
|
||||
active_widget_ids,
|
||||
ctx.fragment_ids_this_run,
|
||||
)
|
||||
|
||||
# Remove entries from _old_state corresponding to
|
||||
# widgets not in widget_ids.
|
||||
self._old_state = {
|
||||
k: v
|
||||
for k, v in self._old_state.items()
|
||||
if (
|
||||
not is_element_id(k)
|
||||
or not _is_stale_widget(
|
||||
self._new_widget_state.widget_metadata.get(k),
|
||||
active_widget_ids,
|
||||
ctx.fragment_ids_this_run,
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
def _set_widget_metadata(self, widget_metadata: WidgetMetadata[Any]) -> None:
|
||||
"""Set a widget's metadata."""
|
||||
widget_id = widget_metadata.id
|
||||
self._new_widget_state.widget_metadata[widget_id] = widget_metadata
|
||||
|
||||
def get_widget_states(self) -> list[WidgetStateProto]:
|
||||
"""Return a list of serialized widget values for each widget with a value."""
|
||||
return self._new_widget_state.as_widget_states()
|
||||
|
||||
def _get_widget_id(self, k: str) -> str:
|
||||
"""Turns a value that might be a widget id or a user provided key into
|
||||
an appropriate widget id.
|
||||
"""
|
||||
return self._key_id_mapper.get_id_from_key(k, k)
|
||||
|
||||
def _set_key_widget_mapping(self, widget_id: str, user_key: str) -> None:
|
||||
self._key_id_mapper[user_key] = widget_id
|
||||
|
||||
def register_widget(
|
||||
self, metadata: WidgetMetadata[T], user_key: str | None
|
||||
) -> RegisterWidgetResult[T]:
|
||||
"""Register a widget with the SessionState.
|
||||
|
||||
Returns
|
||||
-------
|
||||
RegisterWidgetResult[T]
|
||||
Contains the widget's current value, and a bool that will be True
|
||||
if the frontend needs to be updated with the current value.
|
||||
"""
|
||||
widget_id = metadata.id
|
||||
|
||||
self._set_widget_metadata(metadata)
|
||||
if user_key is not None:
|
||||
# If the widget has a user_key, update its user_key:widget_id mapping
|
||||
self._set_key_widget_mapping(widget_id, user_key)
|
||||
|
||||
if widget_id not in self and (user_key is None or user_key not in self):
|
||||
# This is the first time the widget is registered, so we save its
|
||||
# value in widget state.
|
||||
deserializer = metadata.deserializer
|
||||
initial_widget_value = deepcopy(deserializer(None, metadata.id))
|
||||
self._new_widget_state.set_from_value(widget_id, initial_widget_value)
|
||||
|
||||
# Get the current value of the widget for use as its return value.
|
||||
# We return a copy, so that reference types can't be accidentally
|
||||
# mutated by user code.
|
||||
widget_value = cast("T", self[widget_id])
|
||||
widget_value = deepcopy(widget_value)
|
||||
|
||||
# widget_value_changed indicates to the caller that the widget's
|
||||
# current value is different from what is in the frontend.
|
||||
widget_value_changed = user_key is not None and self.is_new_state_value(
|
||||
user_key
|
||||
)
|
||||
|
||||
return RegisterWidgetResult(widget_value, widget_value_changed)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
try:
|
||||
self[key]
|
||||
except KeyError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
# Lazy-load vendored package to prevent import of numpy
|
||||
from streamlit.vendor.pympler.asizeof import asizeof
|
||||
|
||||
stat = CacheStat("st_session_state", "", asizeof(self))
|
||||
return [stat]
|
||||
|
||||
def _check_serializable(self) -> None:
|
||||
"""Verify that everything added to session state can be serialized.
|
||||
We use pickleability as the metric for serializability, and test for
|
||||
pickleability by just trying it.
|
||||
"""
|
||||
for k in self:
|
||||
try:
|
||||
pickle.dumps(self[k])
|
||||
except Exception as e:
|
||||
err_msg = f"""Cannot serialize the value (of type `{type(self[k])}`) of '{k}' in st.session_state.
|
||||
Streamlit has been configured to use [pickle](https://docs.python.org/3/library/pickle.html) to
|
||||
serialize session_state values. Please convert the value to a pickle-serializable type. To learn
|
||||
more about this behavior, see [our docs](https://docs.streamlit.io/knowledge-base/using-streamlit/serializable-session-state). """
|
||||
raise UnserializableSessionStateError(err_msg) from e
|
||||
|
||||
def maybe_check_serializable(self) -> None:
|
||||
"""Verify that session state can be serialized, if the relevant config
|
||||
option is set.
|
||||
|
||||
See `_check_serializable` for details.
|
||||
"""
|
||||
if config.get_option("runner.enforceSerializableSessionState"):
|
||||
self._check_serializable()
|
||||
|
||||
|
||||
def _is_internal_key(key: str) -> bool:
|
||||
return key.startswith(STREAMLIT_INTERNAL_KEY_PREFIX)
|
||||
|
||||
|
||||
def _is_stale_widget(
|
||||
metadata: WidgetMetadata[Any] | None,
|
||||
active_widget_ids: set[str],
|
||||
fragment_ids_this_run: list[str] | None,
|
||||
) -> bool:
|
||||
if not metadata:
|
||||
return True
|
||||
elif metadata.id in active_widget_ids:
|
||||
return False
|
||||
# If we're running 1 or more fragments, but this widget is unrelated to any of the
|
||||
# fragments that we're running, then it should not be marked as stale as its value
|
||||
# may still be needed for a future fragment run or full script run.
|
||||
elif fragment_ids_this_run and metadata.fragment_id not in fragment_ids_this_run:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStateStatProvider(CacheStatsProvider):
|
||||
_session_mgr: SessionManager
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
stats: list[CacheStat] = []
|
||||
for session_info in self._session_mgr.list_active_sessions():
|
||||
session_state = session_info.session.session_state
|
||||
stats.extend(session_state.get_stats())
|
||||
return group_stats(stats)
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, MutableMapping
|
||||
from typing import Any, Final
|
||||
|
||||
from streamlit import logger as _logger
|
||||
from streamlit import runtime
|
||||
from streamlit.elements.lib.utils import Key
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.state.common import require_valid_user_key
|
||||
from streamlit.runtime.state.safe_session_state import SafeSessionState
|
||||
from streamlit.runtime.state.session_state import SessionState
|
||||
|
||||
_LOGGER: Final = _logger.get_logger(__name__)
|
||||
|
||||
|
||||
_state_use_warning_already_displayed: bool = False
|
||||
# The mock session state is used as a fallback if the script is run without `streamlit run`
|
||||
_mock_session_state: SafeSessionState | None = None
|
||||
|
||||
|
||||
def get_session_state() -> SafeSessionState:
|
||||
"""Get the SessionState object for the current session.
|
||||
|
||||
Note that in streamlit scripts, this function should not be called
|
||||
directly. Instead, SessionState objects should be accessed via
|
||||
st.session_state.
|
||||
"""
|
||||
global _state_use_warning_already_displayed
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import (
|
||||
get_script_run_ctx,
|
||||
)
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
|
||||
# If there is no script run context because the script is run bare, we
|
||||
# use a global mock session state version to allow bare script execution (via python script.py)
|
||||
if ctx is None:
|
||||
if not _state_use_warning_already_displayed:
|
||||
_state_use_warning_already_displayed = True
|
||||
if not runtime.exists():
|
||||
_LOGGER.warning(
|
||||
"Session state does not function when running a script without `streamlit run`"
|
||||
)
|
||||
|
||||
global _mock_session_state
|
||||
|
||||
if _mock_session_state is None:
|
||||
# Lazy initialize the mock session state
|
||||
_mock_session_state = SafeSessionState(SessionState(), lambda: None)
|
||||
return _mock_session_state
|
||||
return ctx.session_state
|
||||
|
||||
|
||||
class SessionStateProxy(MutableMapping[Key, Any]):
|
||||
"""A stateless singleton that proxies `st.session_state` interactions
|
||||
to the current script thread's SessionState instance.
|
||||
|
||||
The proxy API differs slightly from SessionState: it does not allow
|
||||
callers to get, set, or iterate over "keyless" widgets (that is, widgets
|
||||
that were created without a user_key, and have autogenerated keys).
|
||||
"""
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
"""Iterator over user state and keyed widget values."""
|
||||
# TODO: this is unsafe if fastReruns is true! Let's deprecate/remove.
|
||||
return iter(get_session_state().filtered_state)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Number of user state and keyed widget values in session_state."""
|
||||
return len(get_session_state().filtered_state)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of user state and keyed widget values."""
|
||||
return str(get_session_state().filtered_state)
|
||||
|
||||
def __getitem__(self, key: Key) -> Any:
|
||||
"""Return the state or widget value with the given key.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
If the key is not a valid SessionState user key.
|
||||
"""
|
||||
key = str(key)
|
||||
require_valid_user_key(key)
|
||||
return get_session_state()[key]
|
||||
|
||||
@gather_metrics("session_state.set_item")
|
||||
def __setitem__(self, key: Key, value: Any) -> None:
|
||||
"""Set the value of the given key.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
If the key is not a valid SessionState user key.
|
||||
"""
|
||||
key = str(key)
|
||||
require_valid_user_key(key)
|
||||
get_session_state()[key] = value
|
||||
|
||||
def __delitem__(self, key: Key) -> None:
|
||||
"""Delete the value with the given key.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
If the key is not a valid SessionState user key.
|
||||
"""
|
||||
key = str(key)
|
||||
require_valid_user_key(key)
|
||||
del get_session_state()[key]
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(_missing_attr_error_message(key))
|
||||
|
||||
@gather_metrics("session_state.set_attr")
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key: str) -> None:
|
||||
try:
|
||||
del self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(_missing_attr_error_message(key))
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Return a dict containing all session_state and keyed widget values."""
|
||||
return get_session_state().filtered_state
|
||||
|
||||
|
||||
def _missing_attr_error_message(attr_name: str) -> str:
|
||||
return (
|
||||
f'st.session_state has no attribute "{attr_name}". Did you forget to initialize it? '
|
||||
f"More info: https://docs.streamlit.io/develop/concepts/architecture/session-state#initialization"
|
||||
)
|
||||
@@ -0,0 +1,135 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from streamlit.runtime.state.common import (
|
||||
RegisterWidgetResult,
|
||||
T,
|
||||
ValueFieldName,
|
||||
WidgetArgs,
|
||||
WidgetCallback,
|
||||
WidgetDeserializer,
|
||||
WidgetKwargs,
|
||||
WidgetMetadata,
|
||||
WidgetSerializer,
|
||||
user_key_from_element_id,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.scriptrunner import ScriptRunContext
|
||||
|
||||
|
||||
def register_widget(
|
||||
element_id: str,
|
||||
*,
|
||||
deserializer: WidgetDeserializer[T],
|
||||
serializer: WidgetSerializer[T],
|
||||
ctx: ScriptRunContext | None,
|
||||
on_change_handler: WidgetCallback | None = None,
|
||||
args: WidgetArgs | None = None,
|
||||
kwargs: WidgetKwargs | None = None,
|
||||
value_type: ValueFieldName,
|
||||
) -> RegisterWidgetResult[T]:
|
||||
"""Register a widget with Streamlit, and return its current value.
|
||||
NOTE: This function should be called after the proto has been filled.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
element_id : str
|
||||
The id of the element. Must be unique.
|
||||
deserializer : WidgetDeserializer[T]
|
||||
Called to convert a widget's protobuf value to the value returned by
|
||||
its st.<widget_name> function.
|
||||
serializer : WidgetSerializer[T]
|
||||
Called to convert a widget's value to its protobuf representation.
|
||||
ctx : ScriptRunContext or None
|
||||
Used to ensure uniqueness of widget IDs, and to look up widget values.
|
||||
on_change_handler : WidgetCallback or None
|
||||
An optional callback invoked when the widget's value changes.
|
||||
args : WidgetArgs or None
|
||||
args to pass to on_change_handler when invoked
|
||||
kwargs : WidgetKwargs or None
|
||||
kwargs to pass to on_change_handler when invoked
|
||||
value_type: ValueType
|
||||
The value_type the widget is going to use.
|
||||
We use this information to start with a best-effort guess for the value_type
|
||||
of each widget. Once we actually receive a proto for a widget from the
|
||||
frontend, the guess is updated to be the correct type. Unfortunately, we're
|
||||
not able to always rely on the proto as the type may be needed earlier.
|
||||
Thankfully, in these cases (when value_type == "trigger_value"), the static
|
||||
table here being slightly inaccurate should never pose a problem.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
register_widget_result : RegisterWidgetResult[T]
|
||||
Provides information on which value to return to the widget caller,
|
||||
and whether the UI needs updating.
|
||||
|
||||
- Unhappy path:
|
||||
- Our ScriptRunContext doesn't exist (meaning that we're running
|
||||
as a "bare script" outside streamlit).
|
||||
- We are disconnected from the SessionState instance.
|
||||
In both cases we'll return a fallback RegisterWidgetResult[T].
|
||||
- Happy path:
|
||||
- The widget has already been registered on a previous run but the
|
||||
user hasn't interacted with it on the client. The widget will have
|
||||
the default value it was first created with. We then return a
|
||||
RegisterWidgetResult[T], containing this value.
|
||||
- The widget has already been registered and the user *has*
|
||||
interacted with it. The widget will have that most recent
|
||||
user-specified value. We then return a RegisterWidgetResult[T],
|
||||
containing this value.
|
||||
|
||||
For both paths a widget return value is provided, allowing the widgets
|
||||
to be used in a non-streamlit setting.
|
||||
"""
|
||||
# Create the widget's updated metadata, and register it with session_state.
|
||||
metadata = WidgetMetadata(
|
||||
element_id,
|
||||
deserializer,
|
||||
serializer,
|
||||
value_type=value_type,
|
||||
callback=on_change_handler,
|
||||
callback_args=args,
|
||||
callback_kwargs=kwargs,
|
||||
fragment_id=ctx.current_fragment_id if ctx else None,
|
||||
)
|
||||
return register_widget_from_metadata(metadata, ctx)
|
||||
|
||||
|
||||
def register_widget_from_metadata(
|
||||
metadata: WidgetMetadata[T],
|
||||
ctx: ScriptRunContext | None,
|
||||
) -> RegisterWidgetResult[T]:
|
||||
"""Register a widget and return its value, using an already constructed
|
||||
`WidgetMetadata`.
|
||||
|
||||
This is split out from `register_widget` to allow caching code to replay
|
||||
widgets by saving and reusing the completed metadata.
|
||||
|
||||
See `register_widget` for details on what this returns.
|
||||
"""
|
||||
if ctx is None:
|
||||
# Early-out if we don't have a script run context (which probably means
|
||||
# we're running as a "bare" Python script, and not via `streamlit run`).
|
||||
return RegisterWidgetResult.failure(deserializer=metadata.deserializer)
|
||||
|
||||
widget_id = metadata.id
|
||||
user_key = user_key_from_element_id(widget_id)
|
||||
|
||||
return ctx.session_state.register_widget(metadata, user_key)
|
||||
109
myenv/lib/python3.11/site-packages/streamlit/runtime/stats.py
Normal file
109
myenv/lib/python3.11/site-packages/streamlit/runtime/stats.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.proto.openmetrics_data_model_pb2 import Metric as MetricProto
|
||||
|
||||
|
||||
class CacheStat(NamedTuple):
|
||||
"""Describes a single cache entry.
|
||||
|
||||
Properties
|
||||
----------
|
||||
category_name : str
|
||||
A human-readable name for the cache "category" that the entry belongs
|
||||
to - e.g. "st.memo", "session_state", etc.
|
||||
cache_name : str
|
||||
A human-readable name for cache instance that the entry belongs to.
|
||||
For "st.memo" and other function decorator caches, this might be the
|
||||
name of the cached function. If the cache category doesn't have
|
||||
multiple separate cache instances, this can just be the empty string.
|
||||
byte_length : int
|
||||
The entry's memory footprint in bytes.
|
||||
"""
|
||||
|
||||
category_name: str
|
||||
cache_name: str
|
||||
byte_length: int
|
||||
|
||||
def to_metric_str(self) -> str:
|
||||
return f'cache_memory_bytes{{cache_type="{self.category_name}",cache="{self.cache_name}"}} {self.byte_length}'
|
||||
|
||||
def marshall_metric_proto(self, metric: MetricProto) -> None:
|
||||
"""Fill an OpenMetrics `Metric` protobuf object."""
|
||||
label = metric.labels.add()
|
||||
label.name = "cache_type"
|
||||
label.value = self.category_name
|
||||
|
||||
label = metric.labels.add()
|
||||
label.name = "cache"
|
||||
label.value = self.cache_name
|
||||
|
||||
metric_point = metric.metric_points.add()
|
||||
metric_point.gauge_value.int_value = self.byte_length
|
||||
|
||||
|
||||
def group_stats(stats: list[CacheStat]) -> list[CacheStat]:
|
||||
"""Group a list of CacheStats by category_name and cache_name and sum byte_length."""
|
||||
|
||||
def key_function(individual_stat):
|
||||
return individual_stat.category_name, individual_stat.cache_name
|
||||
|
||||
result: list[CacheStat] = []
|
||||
|
||||
sorted_stats = sorted(stats, key=key_function)
|
||||
grouped_stats = itertools.groupby(sorted_stats, key=key_function)
|
||||
|
||||
for (category_name, cache_name), single_group_stats in grouped_stats:
|
||||
result.append(
|
||||
CacheStat(
|
||||
category_name=category_name,
|
||||
cache_name=cache_name,
|
||||
byte_length=sum(item.byte_length for item in single_group_stats),
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CacheStatsProvider(Protocol):
|
||||
@abstractmethod
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StatsManager:
|
||||
def __init__(self):
|
||||
self._cache_stats_providers: list[CacheStatsProvider] = []
|
||||
|
||||
def register_provider(self, provider: CacheStatsProvider) -> None:
|
||||
"""Register a CacheStatsProvider with the manager.
|
||||
This function is not thread-safe. Call it immediately after
|
||||
creation.
|
||||
"""
|
||||
self._cache_stats_providers.append(provider)
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
"""Return a list containing all stats from each registered provider."""
|
||||
all_stats: list[CacheStat] = []
|
||||
for provider in self._cache_stats_providers:
|
||||
all_stats.extend(provider.get_stats())
|
||||
|
||||
return all_stats
|
||||
@@ -0,0 +1,149 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, NamedTuple, Protocol
|
||||
|
||||
from streamlit import util
|
||||
from streamlit.runtime.stats import CacheStatsProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from streamlit.proto.Common_pb2 import FileURLs as FileURLsProto
|
||||
|
||||
|
||||
class UploadedFileRec(NamedTuple):
|
||||
"""Metadata and raw bytes for an uploaded file. Immutable."""
|
||||
|
||||
file_id: str
|
||||
name: str
|
||||
type: str
|
||||
data: bytes
|
||||
|
||||
|
||||
class UploadFileUrlInfo(NamedTuple):
|
||||
"""Information we provide for single file in get_upload_urls."""
|
||||
|
||||
file_id: str
|
||||
upload_url: str
|
||||
delete_url: str
|
||||
|
||||
|
||||
class DeletedFile(NamedTuple):
|
||||
"""Represents a deleted file in deserialized values for st.file_uploader and
|
||||
st.camera_input.
|
||||
|
||||
Return this from st.file_uploader and st.camera_input deserialize (so they can
|
||||
be used in session_state), when widget value contains file record that is missing
|
||||
from the storage.
|
||||
DeleteFile instances filtered out before return final value to the user in script,
|
||||
or before sending to frontend.
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
|
||||
|
||||
class UploadedFile(io.BytesIO):
|
||||
"""A mutable uploaded file.
|
||||
|
||||
This class extends BytesIO, which has copy-on-write semantics when
|
||||
initialized with `bytes`.
|
||||
"""
|
||||
|
||||
def __init__(self, record: UploadedFileRec, file_urls: FileURLsProto):
|
||||
# BytesIO's copy-on-write semantics doesn't seem to be mentioned in
|
||||
# the Python docs - possibly because it's a CPython-only optimization
|
||||
# and not guaranteed to be in other Python runtimes. But it's detailed
|
||||
# here: https://hg.python.org/cpython/rev/79a5fbe2c78f
|
||||
super().__init__(record.data)
|
||||
self.file_id = record.file_id
|
||||
self.name = record.name
|
||||
self.type = record.type
|
||||
self.size = len(record.data)
|
||||
self._file_urls = file_urls
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, UploadedFile):
|
||||
return NotImplemented
|
||||
return self.file_id == other.file_id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
|
||||
class UploadedFileManager(CacheStatsProvider, Protocol):
|
||||
"""UploadedFileManager protocol, that should be implemented by the concrete
|
||||
uploaded file managers.
|
||||
|
||||
It is responsible for:
|
||||
- retrieving files by session_id and file_id for st.file_uploader and
|
||||
st.camera_input
|
||||
- cleaning up uploaded files associated with session on session end
|
||||
|
||||
It should be created during Runtime initialization.
|
||||
|
||||
Optionally UploadedFileManager could be responsible for issuing URLs which will be
|
||||
used by frontend to upload files to.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_files(
|
||||
self, session_id: str, file_ids: Sequence[str]
|
||||
) -> list[UploadedFileRec]:
|
||||
"""Return a list of UploadedFileRec for a given sequence of file_ids.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The ID of the session that owns the files.
|
||||
file_ids
|
||||
The sequence of ids associated with files to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[UploadedFileRec]
|
||||
A list of URL UploadedFileRec instances, each instance contains information
|
||||
about uploaded file.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def remove_session_files(self, session_id: str) -> None:
|
||||
"""Remove all files associated with a given session."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_upload_urls(
|
||||
self, session_id: str, file_names: Sequence[str]
|
||||
) -> list[UploadFileUrlInfo]:
|
||||
"""Return a list of UploadFileUrlInfo for a given sequence of file_names.
|
||||
Optional to implement, issuing of URLs could be done by other service.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The ID of the session that request URLs.
|
||||
file_names
|
||||
The sequence of file names for which URLs are requested
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[UploadFileUrlInfo]
|
||||
A list of UploadFileUrlInfo instances, each instance contains information
|
||||
about uploaded file URLs.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,167 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, Final, cast
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.app_session import AppSession
|
||||
from streamlit.runtime.session_manager import (
|
||||
ActiveSessionInfo,
|
||||
SessionClient,
|
||||
SessionInfo,
|
||||
SessionManager,
|
||||
SessionStorage,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.script_data import ScriptData
|
||||
from streamlit.runtime.scriptrunner.script_cache import ScriptCache
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
class WebsocketSessionManager(SessionManager):
|
||||
"""A SessionManager used to manage sessions with lifecycles tied to those of a
|
||||
browser tab's websocket connection.
|
||||
|
||||
WebsocketSessionManagers differentiate between "active" and "inactive" sessions.
|
||||
Active sessions are those with a currently active websocket connection. Inactive
|
||||
sessions are sessions without. Eventual cleanup of inactive sessions is a detail left
|
||||
to the specific SessionStorage that a WebsocketSessionManager is instantiated with.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_storage: SessionStorage,
|
||||
uploaded_file_manager: UploadedFileManager,
|
||||
script_cache: ScriptCache,
|
||||
message_enqueued_callback: Callable[[], None] | None,
|
||||
) -> None:
|
||||
self._session_storage = session_storage
|
||||
self._uploaded_file_mgr = uploaded_file_manager
|
||||
self._script_cache = script_cache
|
||||
self._message_enqueued_callback = message_enqueued_callback
|
||||
|
||||
# Mapping of AppSession.id -> ActiveSessionInfo.
|
||||
self._active_session_info_by_id: dict[str, ActiveSessionInfo] = {}
|
||||
|
||||
def connect_session(
|
||||
self,
|
||||
client: SessionClient,
|
||||
script_data: ScriptData,
|
||||
user_info: dict[str, str | bool | None],
|
||||
existing_session_id: str | None = None,
|
||||
session_id_override: str | None = None,
|
||||
) -> str:
|
||||
assert not (existing_session_id and session_id_override), (
|
||||
"Only one of existing_session_id and session_id_override should be truthy"
|
||||
)
|
||||
|
||||
if existing_session_id in self._active_session_info_by_id:
|
||||
_LOGGER.warning(
|
||||
"Session with id %s is already connected! Connecting to a new session.",
|
||||
existing_session_id,
|
||||
)
|
||||
|
||||
session_info = (
|
||||
existing_session_id
|
||||
and existing_session_id not in self._active_session_info_by_id
|
||||
and self._session_storage.get(existing_session_id)
|
||||
)
|
||||
|
||||
if session_info:
|
||||
existing_session = session_info.session
|
||||
existing_session.register_file_watchers()
|
||||
|
||||
self._active_session_info_by_id[existing_session.id] = ActiveSessionInfo(
|
||||
client,
|
||||
existing_session,
|
||||
session_info.script_run_count,
|
||||
)
|
||||
self._session_storage.delete(existing_session.id)
|
||||
|
||||
return existing_session.id
|
||||
|
||||
session = AppSession(
|
||||
script_data=script_data,
|
||||
uploaded_file_manager=self._uploaded_file_mgr,
|
||||
script_cache=self._script_cache,
|
||||
message_enqueued_callback=self._message_enqueued_callback,
|
||||
user_info=user_info,
|
||||
session_id_override=session_id_override,
|
||||
)
|
||||
|
||||
_LOGGER.debug(
|
||||
"Created new session for client %s. Session ID: %s", id(client), session.id
|
||||
)
|
||||
|
||||
assert session.id not in self._active_session_info_by_id, (
|
||||
f"session.id '{session.id}' registered multiple times!"
|
||||
)
|
||||
|
||||
self._active_session_info_by_id[session.id] = ActiveSessionInfo(client, session)
|
||||
return session.id
|
||||
|
||||
def disconnect_session(self, session_id: str) -> None:
|
||||
if session_id in self._active_session_info_by_id:
|
||||
active_session_info = self._active_session_info_by_id[session_id]
|
||||
session = active_session_info.session
|
||||
|
||||
session.request_script_stop()
|
||||
session.disconnect_file_watchers()
|
||||
|
||||
self._session_storage.save(
|
||||
SessionInfo(
|
||||
client=None,
|
||||
session=session,
|
||||
script_run_count=active_session_info.script_run_count,
|
||||
)
|
||||
)
|
||||
del self._active_session_info_by_id[session_id]
|
||||
|
||||
def get_active_session_info(self, session_id: str) -> ActiveSessionInfo | None:
|
||||
return self._active_session_info_by_id.get(session_id)
|
||||
|
||||
def is_active_session(self, session_id: str) -> bool:
|
||||
return session_id in self._active_session_info_by_id
|
||||
|
||||
def list_active_sessions(self) -> list[ActiveSessionInfo]:
|
||||
return list(self._active_session_info_by_id.values())
|
||||
|
||||
def close_session(self, session_id: str) -> None:
|
||||
if session_id in self._active_session_info_by_id:
|
||||
active_session_info = self._active_session_info_by_id[session_id]
|
||||
del self._active_session_info_by_id[session_id]
|
||||
active_session_info.session.shutdown()
|
||||
return
|
||||
|
||||
session_info = self._session_storage.get(session_id)
|
||||
if session_info:
|
||||
self._session_storage.delete(session_id)
|
||||
session_info.session.shutdown()
|
||||
|
||||
def get_session_info(self, session_id: str) -> SessionInfo | None:
|
||||
session_info = self.get_active_session_info(session_id)
|
||||
if session_info:
|
||||
return cast("SessionInfo", session_info)
|
||||
return self._session_storage.get(session_id)
|
||||
|
||||
def list_sessions(self) -> list[SessionInfo]:
|
||||
return (
|
||||
cast("list[SessionInfo]", self.list_active_sessions())
|
||||
+ self._session_storage.list()
|
||||
)
|
||||
Reference in New Issue
Block a user