Mise à jour de Monitor.py et autres scripts

This commit is contained in:
Debian
2025-07-23 10:46:27 +02:00
parent 7081418ce0
commit 7de3e0fb50
8604 changed files with 2789953 additions and 295 deletions

View File

@@ -0,0 +1,41 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.runtime.state.common import WidgetArgs, WidgetCallback, WidgetKwargs
from streamlit.runtime.state.query_params_proxy import QueryParamsProxy
from streamlit.runtime.state.safe_session_state import SafeSessionState
from streamlit.runtime.state.session_state import (
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
SessionState,
SessionStateStatProvider,
)
from streamlit.runtime.state.session_state_proxy import (
SessionStateProxy,
get_session_state,
)
from streamlit.runtime.state.widgets import register_widget
__all__ = [
"WidgetArgs",
"WidgetCallback",
"WidgetKwargs",
"QueryParamsProxy",
"SafeSessionState",
"SCRIPT_RUN_WITHOUT_ERRORS_KEY",
"SessionState",
"SessionStateStatProvider",
"SessionStateProxy",
"get_session_state",
"register_widget",
]

View File

@@ -0,0 +1,191 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and data structures shared by session_state.py and widgets.py."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import (
Any,
Callable,
Final,
Generic,
Literal,
TypeVar,
cast,
get_args,
)
from typing_extensions import TypeAlias, TypeGuard
from streamlit import util
from streamlit.errors import (
StreamlitAPIException,
)
GENERATED_ELEMENT_ID_PREFIX: Final = "$$ID"
TESTING_KEY = "$$STREAMLIT_INTERNAL_KEY_TESTING"
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
WidgetArgs: TypeAlias = tuple[Any, ...]
WidgetKwargs: TypeAlias = dict[str, Any]
WidgetCallback: TypeAlias = Callable[..., None]
# A deserializer receives the value from whatever field is set on the
# WidgetState proto, and returns a regular python value. A serializer
# receives a regular python value, and returns something suitable for
# a value field on WidgetState proto. They should be inverses.
WidgetDeserializer: TypeAlias = Callable[[Any, str], T]
WidgetSerializer: TypeAlias = Callable[[T], Any]
# The array value field names are part of the larger set of possible value
# field names. See the explanation for said set below. The message types
# associated with these fields are distinguished by storing data in a `data`
# field in their messages, meaning they need special treatment in certain
# circumstances. Hence, they need their own, dedicated, sub-type.
ArrayValueFieldName: TypeAlias = Literal[
"double_array_value",
"int_array_value",
"string_array_value",
]
# A frozenset containing the allowed values of the ArrayValueFieldName type.
# Useful for membership checking.
_ARRAY_VALUE_FIELD_NAMES: Final = frozenset(
cast(
"tuple[ArrayValueFieldName, ...]",
# NOTE: get_args is not recursive, so this only works as long as
# ArrayValueFieldName remains flat.
get_args(ArrayValueFieldName),
)
)
# These are the possible field names that can be set in the `value` oneof-field
# of the WidgetState message (schema found in .proto/WidgetStates.proto).
# We need these as a literal type to ensure correspondence with the protobuf
# schema in certain parts of the python code.
# TODO(harahu): It would be preferable if this type was automatically derived
# from the protobuf schema, rather than manually maintained. Not sure how to
# achieve that, though.
ValueFieldName: TypeAlias = Literal[
ArrayValueFieldName,
"arrow_value",
"bool_value",
"bytes_value",
"double_value",
"file_uploader_state_value",
"int_value",
"json_value",
"string_value",
"trigger_value",
"string_trigger_value",
"chat_input_value",
]
def is_array_value_field_name(obj: object) -> TypeGuard[ArrayValueFieldName]:
return obj in _ARRAY_VALUE_FIELD_NAMES
@dataclass(frozen=True)
class WidgetMetadata(Generic[T]):
"""Metadata associated with a single widget. Immutable."""
id: str
deserializer: WidgetDeserializer[T] = field(repr=False)
serializer: WidgetSerializer[T] = field(repr=False)
value_type: ValueFieldName
# An optional user-code callback invoked when the widget's value changes.
# Widget callbacks are called at the start of a script run, before the
# body of the script is executed.
callback: WidgetCallback | None = None
callback_args: WidgetArgs | None = None
callback_kwargs: WidgetKwargs | None = None
fragment_id: str | None = None
def __repr__(self) -> str:
return util.repr_(self)
@dataclass(frozen=True)
class RegisterWidgetResult(Generic[T_co]):
"""Result returned by the `register_widget` family of functions/methods.
Should be usable by widget code to determine what value to return, and
whether to update the UI.
Parameters
----------
value : T_co
The widget's current value, or, in cases where the true widget value
could not be determined, an appropriate fallback value.
This value should be returned by the widget call.
value_changed : bool
True if the widget's value is different from the value most recently
returned from the frontend.
Implies an update to the frontend is needed.
"""
value: T_co
value_changed: bool
@classmethod
def failure(
cls, deserializer: WidgetDeserializer[T_co]
) -> RegisterWidgetResult[T_co]:
"""The canonical way to construct a RegisterWidgetResult in cases
where the true widget value could not be determined.
"""
return cls(value=deserializer(None, ""), value_changed=False)
def user_key_from_element_id(element_id: str) -> str | None:
"""Return the user key portion of a element id, or None if the id does not
have a user key.
TODO This will incorrectly indicate no user key if the user actually provides
"None" as a key, but we can't avoid this kind of problem while storing the
string representation of the no-user-key sentinel as part of the element id.
"""
user_key: str | None = element_id.split("-", maxsplit=2)[-1]
return None if user_key == "None" else user_key
def is_element_id(key: str) -> bool:
"""True if the given session_state key has the structure of a element ID."""
return key.startswith(GENERATED_ELEMENT_ID_PREFIX)
def is_keyed_element_id(key: str) -> bool:
"""True if the given session_state key has the structure of a element ID
with a user_key.
"""
return is_element_id(key) and not key.endswith("-None")
def require_valid_user_key(key: str) -> None:
"""Raise an Exception if the given user_key is invalid."""
if is_element_id(key):
raise StreamlitAPIException(
f"Keys beginning with {GENERATED_ELEMENT_ID_PREFIX} are reserved."
)

View File

@@ -0,0 +1,205 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterable, Iterator, MutableMapping
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Final
from urllib import parse
from streamlit.errors import StreamlitAPIException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem
EMBED_QUERY_PARAM: Final[str] = "embed"
EMBED_OPTIONS_QUERY_PARAM: Final[str] = "embed_options"
EMBED_QUERY_PARAMS_KEYS: Final[list[str]] = [
EMBED_QUERY_PARAM,
EMBED_OPTIONS_QUERY_PARAM,
]
@dataclass
class QueryParams(MutableMapping[str, str]):
"""A lightweight wrapper of a dict that sends forwardMsgs when state changes.
It stores str keys with str and List[str] values.
"""
_query_params: dict[str, list[str] | str] = field(default_factory=dict)
def __iter__(self) -> Iterator[str]:
self._ensure_single_query_api_used()
return iter(
key
for key in self._query_params.keys()
if key not in EMBED_QUERY_PARAMS_KEYS
)
def __getitem__(self, key: str) -> str:
"""Retrieves a value for a given key in query parameters.
Returns the last item in a list or an empty string if empty.
If the key is not present, raise KeyError.
"""
self._ensure_single_query_api_used()
try:
if key in EMBED_QUERY_PARAMS_KEYS:
raise KeyError(missing_key_error_message(key))
value = self._query_params[key]
if isinstance(value, list):
if len(value) == 0:
return ""
else:
# Return the last value to mimic Tornado's behavior
# https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.get_query_argument
return value[-1]
return value
except KeyError:
raise KeyError(missing_key_error_message(key))
def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
self._ensure_single_query_api_used()
self.__set_item_internal(key, value)
self._send_query_param_msg()
def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None:
if isinstance(value, dict):
raise StreamlitAPIException(
f"You cannot set a query params key `{key}` to a dictionary."
)
if key in EMBED_QUERY_PARAMS_KEYS:
raise StreamlitAPIException(
"Query param embed and embed_options (case-insensitive) cannot be set programmatically."
)
# Type checking users should handle the string serialization themselves
# We will accept any type for the list and serialize to str just in case
if isinstance(value, Iterable) and not isinstance(value, str):
self._query_params[key] = [str(item) for item in value]
else:
self._query_params[key] = str(value)
def __delitem__(self, key: str) -> None:
self._ensure_single_query_api_used()
try:
if key in EMBED_QUERY_PARAMS_KEYS:
raise KeyError(missing_key_error_message(key))
del self._query_params[key]
self._send_query_param_msg()
except KeyError:
raise KeyError(missing_key_error_message(key))
def update(
self,
other: Iterable[tuple[str, str | Iterable[str]]]
| SupportsKeysAndGetItem[str, str | Iterable[str]] = (),
/,
**kwds: str,
):
# This overrides the `update` provided by MutableMapping
# to ensure only one one ForwardMsg is sent.
self._ensure_single_query_api_used()
if hasattr(other, "keys") and hasattr(other, "__getitem__"):
for key in other.keys():
self.__set_item_internal(key, other[key])
else:
for key, value in other:
self.__set_item_internal(key, value)
for key, value in kwds.items():
self.__set_item_internal(key, value)
self._send_query_param_msg()
def get_all(self, key: str) -> list[str]:
self._ensure_single_query_api_used()
if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS:
return []
value = self._query_params[key]
return value if isinstance(value, list) else [value]
def __len__(self) -> int:
self._ensure_single_query_api_used()
return len(
{key for key in self._query_params if key not in EMBED_QUERY_PARAMS_KEYS}
)
def __str__(self) -> str:
self._ensure_single_query_api_used()
return str(self._query_params)
def _send_query_param_msg(self) -> None:
ctx = get_script_run_ctx()
if ctx is None:
return
self._ensure_single_query_api_used()
msg = ForwardMsg()
msg.page_info_changed.query_string = parse.urlencode(
self._query_params, doseq=True
)
ctx.query_string = msg.page_info_changed.query_string
ctx.enqueue(msg)
def clear(self) -> None:
self._ensure_single_query_api_used()
self.clear_with_no_forward_msg(preserve_embed=True)
self._send_query_param_msg()
def to_dict(self) -> dict[str, str]:
self._ensure_single_query_api_used()
# return the last query param if multiple values are set
return {
key: self[key]
for key in self._query_params
if key not in EMBED_QUERY_PARAMS_KEYS
}
def from_dict(
self,
_dict: Iterable[tuple[str, str | Iterable[str]]]
| SupportsKeysAndGetItem[str, str | Iterable[str]],
):
self._ensure_single_query_api_used()
old_value = self._query_params.copy()
self.clear_with_no_forward_msg(preserve_embed=True)
try:
self.update(_dict)
except StreamlitAPIException:
# restore the original from before we made any changes.
self._query_params = old_value
raise
def set_with_no_forward_msg(self, key: str, val: list[str] | str) -> None:
self._query_params[key] = val
def clear_with_no_forward_msg(self, preserve_embed: bool = False) -> None:
self._query_params = {
key: value
for key, value in self._query_params.items()
if key in EMBED_QUERY_PARAMS_KEYS and preserve_embed
}
def _ensure_single_query_api_used(self):
ctx = get_script_run_ctx()
if ctx is None:
return
ctx.mark_production_query_params_used()
def missing_key_error_message(key: str) -> str:
return f'st.query_params has no key "{key}".'

View File

@@ -0,0 +1,218 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterable, Iterator, MutableMapping
from typing import TYPE_CHECKING, overload
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.state.session_state_proxy import get_session_state
if TYPE_CHECKING:
from _typeshed import SupportsKeysAndGetItem
class QueryParamsProxy(MutableMapping[str, str]):
"""
A stateless singleton that proxies ``st.query_params`` interactions
to the current script thread's QueryParams instance.
"""
def __iter__(self) -> Iterator[str]:
with get_session_state().query_params() as qp:
return iter(qp)
def __len__(self) -> int:
with get_session_state().query_params() as qp:
return len(qp)
def __str__(self) -> str:
with get_session_state().query_params() as qp:
return str(qp)
@gather_metrics("query_params.get_item")
def __getitem__(self, key: str) -> str:
with get_session_state().query_params() as qp:
try:
return qp[key]
except KeyError:
raise KeyError(self.missing_key_error_message(key))
def __delitem__(self, key: str) -> None:
with get_session_state().query_params() as qp:
del qp[key]
@gather_metrics("query_params.set_item")
def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
with get_session_state().query_params() as qp:
qp[key] = value
@gather_metrics("query_params.get_attr")
def __getattr__(self, key: str) -> str:
with get_session_state().query_params() as qp:
try:
return qp[key]
except KeyError:
raise AttributeError(self.missing_attr_error_message(key))
def __delattr__(self, key: str) -> None:
with get_session_state().query_params() as qp:
try:
del qp[key]
except KeyError:
raise AttributeError(self.missing_key_error_message(key))
@overload
def update(
self, mapping: SupportsKeysAndGetItem[str, str | Iterable[str]], /, **kwds: str
) -> None: ...
@overload
def update(
self, keys_and_values: Iterable[tuple[str, str | Iterable[str]]], /, **kwds: str
) -> None: ...
@overload
def update(self, **kwds: str | Iterable[str]) -> None: ...
def update(self, other=(), /, **kwds):
"""
Update one or more values in query_params at once from a dictionary or
dictionary-like object.
See `Mapping.update()` from Python's `collections` library.
Parameters
----------
other: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]]
A dictionary or mapping of strings to strings.
**kwds: str
Additional key/value pairs to update passed as keyword arguments.
"""
with get_session_state().query_params() as qp:
qp.update(other, **kwds)
@gather_metrics("query_params.set_attr")
def __setattr__(self, key: str, value: str | Iterable[str]) -> None:
with get_session_state().query_params() as qp:
qp[key] = value
@gather_metrics("query_params.get_all")
def get_all(self, key: str) -> list[str]:
"""
Get a list of all query parameter values associated to a given key.
When a key is repeated as a query parameter within the URL, this method
allows all values to be obtained. In contrast, dict-like methods only
retrieve the last value when a key is repeated in the URL.
Parameters
----------
key: str
The label of the query parameter in the URL.
Returns
-------
List[str]
A list of values associated to the given key. May return zero, one,
or multiple values.
"""
with get_session_state().query_params() as qp:
return qp.get_all(key)
@gather_metrics("query_params.clear")
def clear(self) -> None:
"""
Clear all query parameters from the URL of the app.
Returns
-------
None
"""
with get_session_state().query_params() as qp:
qp.clear()
@gather_metrics("query_params.to_dict")
def to_dict(self) -> dict[str, str]:
"""
Get all query parameters as a dictionary.
This method primarily exists for internal use and is not needed for
most cases. ``st.query_params`` returns an object that inherits from
``dict`` by default.
When a key is repeated as a query parameter within the URL, this method
will return only the last value of each unique key.
Returns
-------
Dict[str,str]
A dictionary of the current query paramters in the app's URL.
"""
with get_session_state().query_params() as qp:
return qp.to_dict()
@overload
def from_dict(
self, keys_and_values: Iterable[tuple[str, str | Iterable[str]]]
) -> None: ...
@overload
def from_dict(
self, mapping: SupportsKeysAndGetItem[str, str | Iterable[str]]
) -> None: ...
@gather_metrics("query_params.from_dict")
def from_dict(self, params):
"""
Set all of the query parameters from a dictionary or dictionary-like object.
This method primarily exists for advanced users who want to control
multiple query parameters in a single update. To set individual query
parameters, use key or attribute notation instead.
This method inherits limitations from ``st.query_params`` and can't be
used to set embedding options as described in `Embed your app \
<https://docs.streamlit.io/deploy/streamlit-community-cloud/share-your-app/embed-your-app#embed-options>`_.
To handle repeated keys, the value in a key-value pair should be a list.
.. note::
``.from_dict()`` is not a direct inverse of ``.to_dict()`` if
you are working with repeated keys. A true inverse operation is
``{key: st.query_params.get_all(key) for key in st.query_params}``.
Parameters
----------
params: dict
A dictionary used to replace the current query parameters.
Example
-------
>>> import streamlit as st
>>>
>>> st.query_params.from_dict({"foo": "bar", "baz": [1, "two"]})
"""
with get_session_state().query_params() as qp:
return qp.from_dict(params)
@staticmethod
def missing_key_error_message(key: str) -> str:
return f'st.query_params has no key "{key}".'
@staticmethod
def missing_attr_error_message(key: str) -> str:
return f'st.query_params has no attribute "{key}".'

View File

@@ -0,0 +1,138 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import threading
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
from collections.abc import Iterator
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
from streamlit.runtime.state.common import RegisterWidgetResult, T, WidgetMetadata
from streamlit.runtime.state.query_params import QueryParams
from streamlit.runtime.state.session_state import SessionState
class SafeSessionState:
"""Thread-safe wrapper around SessionState.
When AppSession gets a re-run request, it can interrupt its existing
ScriptRunner and spin up a new ScriptRunner to handle the request.
When this happens, the existing ScriptRunner will continue executing
its script until it reaches a yield point - but during this time, it
must not mutate its SessionState.
"""
_state: SessionState
_lock: threading.RLock
_yield_callback: Callable[[], None]
def __init__(self, state: SessionState, yield_callback: Callable[[], None]):
# Fields must be set using the object's setattr method to avoid
# infinite recursion from trying to look up the fields we're setting.
object.__setattr__(self, "_state", state)
# TODO: we'd prefer this be a threading.Lock instead of RLock -
# but `call_callbacks` first needs to be rewritten.
object.__setattr__(self, "_lock", threading.RLock())
object.__setattr__(self, "_yield_callback", yield_callback)
def register_widget(
self, metadata: WidgetMetadata[T], user_key: str | None
) -> RegisterWidgetResult[T]:
self._yield_callback()
with self._lock:
return self._state.register_widget(metadata, user_key)
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
self._yield_callback()
with self._lock:
# TODO: rewrite this to copy the callbacks list into a local
# variable so that we don't need to hold our lock for the
# duration. (This will also allow us to downgrade our RLock
# to a Lock.)
self._state.on_script_will_rerun(latest_widget_states)
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
with self._lock:
self._state.on_script_finished(widget_ids_this_run)
def maybe_check_serializable(self) -> None:
with self._lock:
self._state.maybe_check_serializable()
def get_widget_states(self) -> list[WidgetStateProto]:
"""Return a list of serialized widget values for each widget with a value."""
with self._lock:
return self._state.get_widget_states()
def is_new_state_value(self, user_key: str) -> bool:
with self._lock:
return self._state.is_new_state_value(user_key)
@property
def filtered_state(self) -> dict[str, Any]:
"""The combined session and widget state, excluding keyless widgets."""
with self._lock:
return self._state.filtered_state
def __getitem__(self, key: str) -> Any:
self._yield_callback()
with self._lock:
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
self._yield_callback()
with self._lock:
self._state[key] = value
def __delitem__(self, key: str) -> None:
self._yield_callback()
with self._lock:
del self._state[key]
def __contains__(self, key: str) -> bool:
self._yield_callback()
with self._lock:
return key in self._state
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(f"{key} not found in session_state.")
def __setattr__(self, key: str, value: Any) -> None:
self[key] = value
def __delattr__(self, key: str) -> None:
try:
del self[key]
except KeyError:
raise AttributeError(f"{key} not found in session_state.")
def __repr__(self):
"""Presents itself as a simple dict of the underlying SessionState instance."""
kv = ((k, self._state[k]) for k in self._state._keys())
s = ", ".join(f"{k}: {v!r}" for k, v in kv)
return f"{{{s}}}"
@contextmanager
def query_params(self) -> Iterator[QueryParams]:
self._yield_callback()
with self._lock:
yield self._state.query_params

View File

@@ -0,0 +1,773 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import pickle
from collections.abc import Iterator, KeysView, MutableMapping
from copy import deepcopy
from dataclasses import dataclass, field, replace
from typing import (
TYPE_CHECKING,
Any,
Final,
Union,
cast,
)
from typing_extensions import TypeAlias
import streamlit as st
from streamlit import config, util
from streamlit.errors import StreamlitAPIException, UnserializableSessionStateError
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
from streamlit.runtime.state.common import (
RegisterWidgetResult,
T,
ValueFieldName,
WidgetMetadata,
is_array_value_field_name,
is_element_id,
is_keyed_element_id,
)
from streamlit.runtime.state.query_params import QueryParams
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
if TYPE_CHECKING:
from streamlit.runtime.session_manager import SessionManager
STREAMLIT_INTERNAL_KEY_PREFIX: Final = "$$STREAMLIT_INTERNAL_KEY"
SCRIPT_RUN_WITHOUT_ERRORS_KEY: Final = (
f"{STREAMLIT_INTERNAL_KEY_PREFIX}_SCRIPT_RUN_WITHOUT_ERRORS"
)
@dataclass(frozen=True)
class Serialized:
"""A widget value that's serialized to a protobuf. Immutable."""
value: WidgetStateProto
@dataclass(frozen=True)
class Value:
"""A widget value that's not serialized. Immutable."""
value: Any
WState: TypeAlias = Union[Value, Serialized]
@dataclass
class WStates(MutableMapping[str, Any]):
"""A mapping of widget IDs to values. Widget values can be stored in
serialized or deserialized form, but when values are retrieved from the
mapping, they'll always be deserialized.
"""
states: dict[str, WState] = field(default_factory=dict)
widget_metadata: dict[str, WidgetMetadata[Any]] = field(default_factory=dict)
def __repr__(self):
return util.repr_(self)
def __getitem__(self, k: str) -> Any:
"""Return the value of the widget with the given key.
If the widget's value is currently stored in serialized form, it
will be deserialized first.
"""
wstate = self.states.get(k)
if wstate is None:
raise KeyError(k)
if isinstance(wstate, Value):
# The widget's value is already deserialized - return it directly.
return wstate.value
# The widget's value is serialized. We deserialize it, and return
# the deserialized value.
metadata = self.widget_metadata.get(k)
if metadata is None:
# No deserializer, which should only happen if state is
# gotten from a reconnecting browser and the script is
# trying to access it. Pretend it doesn't exist.
raise KeyError(k)
value_field_name = cast(
"ValueFieldName",
wstate.value.WhichOneof("value"),
)
value = (
wstate.value.__getattribute__(value_field_name)
if value_field_name # Field name is None if the widget value was cleared
else None
)
if is_array_value_field_name(value_field_name):
# Array types are messages with data in a `data` field
value = value.data
elif value_field_name == "json_value":
value = json.loads(value)
deserialized = metadata.deserializer(value, metadata.id)
# Update metadata to reflect information from WidgetState proto
self.set_widget_metadata(
replace(
metadata,
value_type=value_field_name,
)
)
self.states[k] = Value(deserialized)
return deserialized
def __setitem__(self, k: str, v: WState) -> None:
self.states[k] = v
def __delitem__(self, k: str) -> None:
del self.states[k]
def __len__(self) -> int:
return len(self.states)
def __iter__(self):
# For this and many other methods, we can't simply delegate to the
# states field, because we need to invoke `__getitem__` for any
# values, to handle deserialization and unwrapping of values.
yield from self.states
def keys(self) -> KeysView[str]:
return KeysView(self.states)
def items(self) -> set[tuple[str, Any]]: # type: ignore[override]
return {(k, self[k]) for k in self}
def values(self) -> set[Any]: # type: ignore[override]
return {self[wid] for wid in self}
def update(self, other: WStates) -> None: # type: ignore[override]
"""Copy all widget values and metadata from 'other' into this mapping,
overwriting any data in this mapping that's also present in 'other'.
"""
self.states.update(other.states)
self.widget_metadata.update(other.widget_metadata)
def set_widget_from_proto(self, widget_state: WidgetStateProto) -> None:
"""Set a widget's serialized value, overwriting any existing value it has."""
self[widget_state.id] = Serialized(widget_state)
def set_from_value(self, k: str, v: Any) -> None:
"""Set a widget's deserialized value, overwriting any existing value it has."""
self[k] = Value(v)
def set_widget_metadata(self, widget_meta: WidgetMetadata[Any]) -> None:
"""Set a widget's metadata, overwriting any existing metadata it has."""
self.widget_metadata[widget_meta.id] = widget_meta
def remove_stale_widgets(
self,
active_widget_ids: set[str],
fragment_ids_this_run: list[str] | None,
) -> None:
"""Remove widget state for stale widgets."""
self.states = {
k: v
for k, v in self.states.items()
if not _is_stale_widget(
self.widget_metadata.get(k),
active_widget_ids,
fragment_ids_this_run,
)
}
def get_serialized(self, k: str) -> WidgetStateProto | None:
"""Get the serialized value of the widget with the given id.
If the widget doesn't exist, return None. If the widget exists but
is not in serialized form, it will be serialized first.
"""
item = self.states.get(k)
if item is None:
# No such widget: return None.
return None
if isinstance(item, Serialized):
# Widget value is serialized: return it directly.
return item.value
# Widget value is not serialized: serialize it first!
metadata = self.widget_metadata.get(k)
if metadata is None:
# We're missing the widget's metadata. (Can this happen?)
return None
widget = WidgetStateProto()
widget.id = k
field = metadata.value_type
serialized = metadata.serializer(item.value)
if is_array_value_field_name(field):
arr = getattr(widget, field)
arr.data.extend(serialized)
elif field == "json_value":
setattr(widget, field, json.dumps(serialized))
elif field == "file_uploader_state_value":
widget.file_uploader_state_value.CopyFrom(serialized)
elif field == "string_trigger_value":
widget.string_trigger_value.CopyFrom(serialized)
elif field == "chat_input_value":
widget.chat_input_value.CopyFrom(serialized)
elif field is not None and serialized is not None:
# If the field is None, the widget value was cleared
# by the user and therefore is None. But we cannot
# set it to None here, since the proto properties are
# not nullable. So we just don't set it.
setattr(widget, field, serialized)
return widget
def as_widget_states(self) -> list[WidgetStateProto]:
"""Return a list of serialized widget values for each widget with a value."""
states = [
self.get_serialized(widget_id)
for widget_id in self.states.keys()
if self.get_serialized(widget_id)
]
states = cast("list[WidgetStateProto]", states)
return states
def call_callback(self, widget_id: str) -> None:
"""Call the given widget's callback and return the callback's
return value. If the widget has no callback, return None.
If the widget doesn't exist, raise an Exception.
"""
metadata = self.widget_metadata.get(widget_id)
assert metadata is not None
callback = metadata.callback
if callback is None:
return
args = metadata.callback_args or ()
kwargs = metadata.callback_kwargs or {}
callback(*args, **kwargs)
def _missing_key_error_message(key: str) -> str:
return (
f'st.session_state has no key "{key}". Did you forget to initialize it? '
f"More info: https://docs.streamlit.io/develop/concepts/architecture/session-state#initialization"
)
@dataclass
class KeyIdMapper:
"""A mapping of user-provided keys to element IDs.
It also maps element IDs to user-provided keys so that this reverse mapping
does not have to be computed ad-hoc.
All built-in dict-operations such as setting and deleting expect the key as the
argument, not the element ID.
"""
_key_id_mapping: dict[str, str] = field(default_factory=dict)
_id_key_mapping: dict[str, str] = field(default_factory=dict)
def __contains__(self, key: str) -> bool:
return key in self._key_id_mapping
def __setitem__(self, key: str, widget_id: Any) -> None:
self._key_id_mapping[key] = widget_id
self._id_key_mapping[widget_id] = key
def __delitem__(self, key: str) -> None:
self.delete(key)
@property
def id_key_mapping(self) -> dict[str, str]:
return self._id_key_mapping
def set_key_id_mapping(self, key_id_mapping: dict[str, str]) -> None:
self._key_id_mapping = key_id_mapping
self._id_key_mapping = {v: k for k, v in key_id_mapping.items()}
def get_id_from_key(self, key: str, default: Any = None) -> str:
return self._key_id_mapping.get(key, default)
def get_key_from_id(self, widget_id: str) -> str:
return self._id_key_mapping[widget_id]
def update(self, other: KeyIdMapper) -> None:
self._key_id_mapping.update(other._key_id_mapping)
self._id_key_mapping.update(other._id_key_mapping)
def clear(self):
self._key_id_mapping.clear()
self._id_key_mapping.clear()
def delete(self, key: str):
widget_id = self._key_id_mapping[key]
del self._key_id_mapping[key]
del self._id_key_mapping[widget_id]
@dataclass
class SessionState:
"""SessionState allows users to store values that persist between app
reruns.
Example
-------
>>> if "num_script_runs" not in st.session_state:
... st.session_state.num_script_runs = 0
>>> st.session_state.num_script_runs += 1
>>> st.write(st.session_state.num_script_runs) # writes 1
The next time your script runs, the value of
st.session_state.num_script_runs will be preserved.
>>> st.session_state.num_script_runs += 1
>>> st.write(st.session_state.num_script_runs) # writes 2
"""
# All the values from previous script runs, squished together to save memory
_old_state: dict[str, Any] = field(default_factory=dict)
# Values set in session state during the current script run, possibly for
# setting a widget's value. Keyed by a user provided string.
_new_session_state: dict[str, Any] = field(default_factory=dict)
# Widget values from the frontend, usually one changing prompted the script rerun
_new_widget_state: WStates = field(default_factory=WStates)
# Keys used for widgets will be eagerly converted to the matching element id
_key_id_mapper: KeyIdMapper = field(default_factory=KeyIdMapper)
# query params are stored in session state because query params will be tied with
# widget state at one point.
query_params: QueryParams = field(default_factory=QueryParams)
def __repr__(self):
return util.repr_(self)
# is it possible for a value to get through this without being deserialized?
def _compact_state(self) -> None:
"""Copy all current session_state and widget_state values into our
_old_state dict, and then clear our current session_state and
widget_state.
"""
for key_or_wid in self:
try:
self._old_state[key_or_wid] = self[key_or_wid]
except KeyError:
# handle key errors from widget state not having metadata gracefully
# https://github.com/streamlit/streamlit/issues/7206
pass
self._new_session_state.clear()
self._new_widget_state.clear()
def clear(self) -> None:
"""Reset self completely, clearing all current and old values."""
self._old_state.clear()
self._new_session_state.clear()
self._new_widget_state.clear()
self._key_id_mapper.clear()
@property
def filtered_state(self) -> dict[str, Any]:
"""The combined session and widget state, excluding keyless widgets."""
wid_key_map = self._key_id_mapper.id_key_mapping
state: dict[str, Any] = {}
# We can't write `for k, v in self.items()` here because doing so will
# run into a `KeyError` if widget metadata has been cleared (which
# happens when the streamlit server restarted or the cache was cleared),
# then we receive a widget's state from a browser.
for k in self._keys():
if not is_element_id(k) and not _is_internal_key(k):
state[k] = self[k]
elif is_keyed_element_id(k):
try:
key = wid_key_map[k]
state[key] = self[k]
except KeyError:
# Widget id no longer maps to a key, it is a not yet
# cleared value in old state for a reset widget
pass
return state
def _keys(self) -> set[str]:
"""All keys active in Session State, with widget keys converted
to widget ids when one is known. (This includes autogenerated keys
for widgets that don't have user_keys defined, and which aren't
exposed to user code).
"""
old_keys = {self._get_widget_id(k) for k in self._old_state.keys()}
new_widget_keys = set(self._new_widget_state.keys())
new_session_state_keys = {
self._get_widget_id(k) for k in self._new_session_state.keys()
}
return old_keys | new_widget_keys | new_session_state_keys
def is_new_state_value(self, user_key: str) -> bool:
"""True if a value with the given key is in the current session state."""
return user_key in self._new_session_state
def __iter__(self) -> Iterator[Any]:
"""Return an iterator over the keys of the SessionState.
This is a shortcut for `iter(self.keys())`.
"""
return iter(self._keys())
def __len__(self) -> int:
"""Return the number of items in SessionState."""
return len(self._keys())
def __getitem__(self, key: str) -> Any:
wid_key_map = self._key_id_mapper.id_key_mapping
widget_id = self._get_widget_id(key)
if widget_id in wid_key_map and widget_id == key:
# the "key" is a raw widget id, so get its associated user key for lookup
key = wid_key_map[widget_id]
try:
return self._getitem(widget_id, key)
except KeyError:
raise KeyError(_missing_key_error_message(key))
def _getitem(self, widget_id: str | None, user_key: str | None) -> Any:
"""Get the value of an entry in Session State, using either the
user-provided key or a widget id as appropriate for the internal dict
being accessed.
At least one of the arguments must have a value.
"""
assert user_key is not None or widget_id is not None
if user_key is not None:
try:
return self._new_session_state[user_key]
except KeyError:
pass
if widget_id is not None:
try:
return self._new_widget_state[widget_id]
except KeyError:
pass
# Typically, there won't be both a widget id and an associated state key in
# old state at the same time, so the order we check is arbitrary.
# The exception is if session state is set and then a later run has
# a widget created, so the widget id entry should be newer.
# The opposite case shouldn't happen, because setting the value of a widget
# through session state will result in the next widget state reflecting that
# value.
if widget_id is not None:
try:
return self._old_state[widget_id]
except KeyError:
pass
if user_key is not None:
try:
return self._old_state[user_key]
except KeyError:
pass
# We'll never get here
raise KeyError
def __setitem__(self, user_key: str, value: Any) -> None:
"""Set the value of the session_state entry with the given user_key.
If the key corresponds to a widget or form that's been instantiated
during the current script run, raise a StreamlitAPIException instead.
"""
ctx = get_script_run_ctx()
if ctx is not None:
widget_id = self._key_id_mapper.get_id_from_key(user_key, None)
widget_ids = ctx.widget_ids_this_run
form_ids = ctx.form_ids_this_run
if widget_id in widget_ids or user_key in form_ids:
raise StreamlitAPIException(
f"`st.session_state.{user_key}` cannot be modified after the widget"
f" with key `{user_key}` is instantiated."
)
self._new_session_state[user_key] = value
def __delitem__(self, key: str) -> None:
widget_id = self._get_widget_id(key)
if not (key in self or widget_id in self):
raise KeyError(_missing_key_error_message(key))
if key in self._new_session_state:
del self._new_session_state[key]
if key in self._old_state:
del self._old_state[key]
if key in self._key_id_mapper:
self._key_id_mapper.delete(key)
if widget_id in self._new_widget_state:
del self._new_widget_state[widget_id]
if widget_id in self._old_state:
del self._old_state[widget_id]
def set_widgets_from_proto(self, widget_states: WidgetStatesProto) -> None:
"""Set the value of all widgets represented in the given WidgetStatesProto."""
for state in widget_states.widgets:
self._new_widget_state.set_widget_from_proto(state)
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
"""Called by ScriptRunner before its script re-runs.
Update widget data and call callbacks on widgets whose value changed
between the previous and current script runs.
"""
# Clear any triggers that weren't reset because the script was disconnected
self._reset_triggers()
self._compact_state()
self.set_widgets_from_proto(latest_widget_states)
self._call_callbacks()
def _call_callbacks(self) -> None:
"""Call any callback associated with each widget whose value
changed between the previous and current script runs.
"""
from streamlit.runtime.scriptrunner import RerunException
changed_widget_ids = [
wid for wid in self._new_widget_state if self._widget_changed(wid)
]
for wid in changed_widget_ids:
try:
self._new_widget_state.call_callback(wid)
except RerunException:
st.warning("Calling st.rerun() within a callback is a no-op.")
def _widget_changed(self, widget_id: str) -> bool:
"""True if the given widget's value changed between the previous
script run and the current script run.
"""
new_value = self._new_widget_state.get(widget_id)
old_value = self._old_state.get(widget_id)
changed: bool = new_value != old_value
return changed
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
"""Called by ScriptRunner after its script finishes running.
Updates widgets to prepare for the next script run.
Parameters
----------
widget_ids_this_run: set[str]
The IDs of the widgets that were accessed during the script
run. Any widget state whose ID does *not* appear in this set
is considered "stale" and will be removed.
"""
self._reset_triggers()
self._remove_stale_widgets(widget_ids_this_run)
def _reset_triggers(self) -> None:
"""Set all trigger values in our state dictionary to False."""
for state_id in self._new_widget_state:
metadata = self._new_widget_state.widget_metadata.get(state_id)
if metadata is not None:
if metadata.value_type == "trigger_value":
self._new_widget_state[state_id] = Value(False)
elif metadata.value_type == "string_trigger_value":
self._new_widget_state[state_id] = Value(None)
elif metadata.value_type == "chat_input_value":
self._new_widget_state[state_id] = Value(None)
for state_id in self._old_state:
metadata = self._new_widget_state.widget_metadata.get(state_id)
if metadata is not None:
if metadata.value_type == "trigger_value":
self._old_state[state_id] = False
elif metadata.value_type == "string_trigger_value":
self._old_state[state_id] = None
elif metadata.value_type == "chat_input_value":
self._old_state[state_id] = None
def _remove_stale_widgets(self, active_widget_ids: set[str]) -> None:
"""Remove widget state for widgets whose ids aren't in `active_widget_ids`."""
ctx = get_script_run_ctx()
if ctx is None:
return
self._new_widget_state.remove_stale_widgets(
active_widget_ids,
ctx.fragment_ids_this_run,
)
# Remove entries from _old_state corresponding to
# widgets not in widget_ids.
self._old_state = {
k: v
for k, v in self._old_state.items()
if (
not is_element_id(k)
or not _is_stale_widget(
self._new_widget_state.widget_metadata.get(k),
active_widget_ids,
ctx.fragment_ids_this_run,
)
)
}
def _set_widget_metadata(self, widget_metadata: WidgetMetadata[Any]) -> None:
"""Set a widget's metadata."""
widget_id = widget_metadata.id
self._new_widget_state.widget_metadata[widget_id] = widget_metadata
def get_widget_states(self) -> list[WidgetStateProto]:
"""Return a list of serialized widget values for each widget with a value."""
return self._new_widget_state.as_widget_states()
def _get_widget_id(self, k: str) -> str:
"""Turns a value that might be a widget id or a user provided key into
an appropriate widget id.
"""
return self._key_id_mapper.get_id_from_key(k, k)
def _set_key_widget_mapping(self, widget_id: str, user_key: str) -> None:
self._key_id_mapper[user_key] = widget_id
def register_widget(
self, metadata: WidgetMetadata[T], user_key: str | None
) -> RegisterWidgetResult[T]:
"""Register a widget with the SessionState.
Returns
-------
RegisterWidgetResult[T]
Contains the widget's current value, and a bool that will be True
if the frontend needs to be updated with the current value.
"""
widget_id = metadata.id
self._set_widget_metadata(metadata)
if user_key is not None:
# If the widget has a user_key, update its user_key:widget_id mapping
self._set_key_widget_mapping(widget_id, user_key)
if widget_id not in self and (user_key is None or user_key not in self):
# This is the first time the widget is registered, so we save its
# value in widget state.
deserializer = metadata.deserializer
initial_widget_value = deepcopy(deserializer(None, metadata.id))
self._new_widget_state.set_from_value(widget_id, initial_widget_value)
# Get the current value of the widget for use as its return value.
# We return a copy, so that reference types can't be accidentally
# mutated by user code.
widget_value = cast("T", self[widget_id])
widget_value = deepcopy(widget_value)
# widget_value_changed indicates to the caller that the widget's
# current value is different from what is in the frontend.
widget_value_changed = user_key is not None and self.is_new_state_value(
user_key
)
return RegisterWidgetResult(widget_value, widget_value_changed)
def __contains__(self, key: str) -> bool:
try:
self[key]
except KeyError:
return False
else:
return True
def get_stats(self) -> list[CacheStat]:
# Lazy-load vendored package to prevent import of numpy
from streamlit.vendor.pympler.asizeof import asizeof
stat = CacheStat("st_session_state", "", asizeof(self))
return [stat]
def _check_serializable(self) -> None:
"""Verify that everything added to session state can be serialized.
We use pickleability as the metric for serializability, and test for
pickleability by just trying it.
"""
for k in self:
try:
pickle.dumps(self[k])
except Exception as e:
err_msg = f"""Cannot serialize the value (of type `{type(self[k])}`) of '{k}' in st.session_state.
Streamlit has been configured to use [pickle](https://docs.python.org/3/library/pickle.html) to
serialize session_state values. Please convert the value to a pickle-serializable type. To learn
more about this behavior, see [our docs](https://docs.streamlit.io/knowledge-base/using-streamlit/serializable-session-state). """
raise UnserializableSessionStateError(err_msg) from e
def maybe_check_serializable(self) -> None:
"""Verify that session state can be serialized, if the relevant config
option is set.
See `_check_serializable` for details.
"""
if config.get_option("runner.enforceSerializableSessionState"):
self._check_serializable()
def _is_internal_key(key: str) -> bool:
return key.startswith(STREAMLIT_INTERNAL_KEY_PREFIX)
def _is_stale_widget(
metadata: WidgetMetadata[Any] | None,
active_widget_ids: set[str],
fragment_ids_this_run: list[str] | None,
) -> bool:
if not metadata:
return True
elif metadata.id in active_widget_ids:
return False
# If we're running 1 or more fragments, but this widget is unrelated to any of the
# fragments that we're running, then it should not be marked as stale as its value
# may still be needed for a future fragment run or full script run.
elif fragment_ids_this_run and metadata.fragment_id not in fragment_ids_this_run:
return False
return True
@dataclass
class SessionStateStatProvider(CacheStatsProvider):
_session_mgr: SessionManager
def get_stats(self) -> list[CacheStat]:
stats: list[CacheStat] = []
for session_info in self._session_mgr.list_active_sessions():
session_state = session_info.session.session_state
stats.extend(session_state.get_stats())
return group_stats(stats)

View File

@@ -0,0 +1,153 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator, MutableMapping
from typing import Any, Final
from streamlit import logger as _logger
from streamlit import runtime
from streamlit.elements.lib.utils import Key
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.state.common import require_valid_user_key
from streamlit.runtime.state.safe_session_state import SafeSessionState
from streamlit.runtime.state.session_state import SessionState
_LOGGER: Final = _logger.get_logger(__name__)
_state_use_warning_already_displayed: bool = False
# The mock session state is used as a fallback if the script is run without `streamlit run`
_mock_session_state: SafeSessionState | None = None
def get_session_state() -> SafeSessionState:
"""Get the SessionState object for the current session.
Note that in streamlit scripts, this function should not be called
directly. Instead, SessionState objects should be accessed via
st.session_state.
"""
global _state_use_warning_already_displayed
from streamlit.runtime.scriptrunner_utils.script_run_context import (
get_script_run_ctx,
)
ctx = get_script_run_ctx()
# If there is no script run context because the script is run bare, we
# use a global mock session state version to allow bare script execution (via python script.py)
if ctx is None:
if not _state_use_warning_already_displayed:
_state_use_warning_already_displayed = True
if not runtime.exists():
_LOGGER.warning(
"Session state does not function when running a script without `streamlit run`"
)
global _mock_session_state
if _mock_session_state is None:
# Lazy initialize the mock session state
_mock_session_state = SafeSessionState(SessionState(), lambda: None)
return _mock_session_state
return ctx.session_state
class SessionStateProxy(MutableMapping[Key, Any]):
"""A stateless singleton that proxies `st.session_state` interactions
to the current script thread's SessionState instance.
The proxy API differs slightly from SessionState: it does not allow
callers to get, set, or iterate over "keyless" widgets (that is, widgets
that were created without a user_key, and have autogenerated keys).
"""
def __iter__(self) -> Iterator[Any]:
"""Iterator over user state and keyed widget values."""
# TODO: this is unsafe if fastReruns is true! Let's deprecate/remove.
return iter(get_session_state().filtered_state)
def __len__(self) -> int:
"""Number of user state and keyed widget values in session_state."""
return len(get_session_state().filtered_state)
def __str__(self) -> str:
"""String representation of user state and keyed widget values."""
return str(get_session_state().filtered_state)
def __getitem__(self, key: Key) -> Any:
"""Return the state or widget value with the given key.
Raises
------
StreamlitAPIException
If the key is not a valid SessionState user key.
"""
key = str(key)
require_valid_user_key(key)
return get_session_state()[key]
@gather_metrics("session_state.set_item")
def __setitem__(self, key: Key, value: Any) -> None:
"""Set the value of the given key.
Raises
------
StreamlitAPIException
If the key is not a valid SessionState user key.
"""
key = str(key)
require_valid_user_key(key)
get_session_state()[key] = value
def __delitem__(self, key: Key) -> None:
"""Delete the value with the given key.
Raises
------
StreamlitAPIException
If the key is not a valid SessionState user key.
"""
key = str(key)
require_valid_user_key(key)
del get_session_state()[key]
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(_missing_attr_error_message(key))
@gather_metrics("session_state.set_attr")
def __setattr__(self, key: str, value: Any) -> None:
self[key] = value
def __delattr__(self, key: str) -> None:
try:
del self[key]
except KeyError:
raise AttributeError(_missing_attr_error_message(key))
def to_dict(self) -> dict[str, Any]:
"""Return a dict containing all session_state and keyed widget values."""
return get_session_state().filtered_state
def _missing_attr_error_message(attr_name: str) -> str:
return (
f'st.session_state has no attribute "{attr_name}". Did you forget to initialize it? '
f"More info: https://docs.streamlit.io/develop/concepts/architecture/session-state#initialization"
)

View File

@@ -0,0 +1,135 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from streamlit.runtime.state.common import (
RegisterWidgetResult,
T,
ValueFieldName,
WidgetArgs,
WidgetCallback,
WidgetDeserializer,
WidgetKwargs,
WidgetMetadata,
WidgetSerializer,
user_key_from_element_id,
)
if TYPE_CHECKING:
from streamlit.runtime.scriptrunner import ScriptRunContext
def register_widget(
element_id: str,
*,
deserializer: WidgetDeserializer[T],
serializer: WidgetSerializer[T],
ctx: ScriptRunContext | None,
on_change_handler: WidgetCallback | None = None,
args: WidgetArgs | None = None,
kwargs: WidgetKwargs | None = None,
value_type: ValueFieldName,
) -> RegisterWidgetResult[T]:
"""Register a widget with Streamlit, and return its current value.
NOTE: This function should be called after the proto has been filled.
Parameters
----------
element_id : str
The id of the element. Must be unique.
deserializer : WidgetDeserializer[T]
Called to convert a widget's protobuf value to the value returned by
its st.<widget_name> function.
serializer : WidgetSerializer[T]
Called to convert a widget's value to its protobuf representation.
ctx : ScriptRunContext or None
Used to ensure uniqueness of widget IDs, and to look up widget values.
on_change_handler : WidgetCallback or None
An optional callback invoked when the widget's value changes.
args : WidgetArgs or None
args to pass to on_change_handler when invoked
kwargs : WidgetKwargs or None
kwargs to pass to on_change_handler when invoked
value_type: ValueType
The value_type the widget is going to use.
We use this information to start with a best-effort guess for the value_type
of each widget. Once we actually receive a proto for a widget from the
frontend, the guess is updated to be the correct type. Unfortunately, we're
not able to always rely on the proto as the type may be needed earlier.
Thankfully, in these cases (when value_type == "trigger_value"), the static
table here being slightly inaccurate should never pose a problem.
Returns
-------
register_widget_result : RegisterWidgetResult[T]
Provides information on which value to return to the widget caller,
and whether the UI needs updating.
- Unhappy path:
- Our ScriptRunContext doesn't exist (meaning that we're running
as a "bare script" outside streamlit).
- We are disconnected from the SessionState instance.
In both cases we'll return a fallback RegisterWidgetResult[T].
- Happy path:
- The widget has already been registered on a previous run but the
user hasn't interacted with it on the client. The widget will have
the default value it was first created with. We then return a
RegisterWidgetResult[T], containing this value.
- The widget has already been registered and the user *has*
interacted with it. The widget will have that most recent
user-specified value. We then return a RegisterWidgetResult[T],
containing this value.
For both paths a widget return value is provided, allowing the widgets
to be used in a non-streamlit setting.
"""
# Create the widget's updated metadata, and register it with session_state.
metadata = WidgetMetadata(
element_id,
deserializer,
serializer,
value_type=value_type,
callback=on_change_handler,
callback_args=args,
callback_kwargs=kwargs,
fragment_id=ctx.current_fragment_id if ctx else None,
)
return register_widget_from_metadata(metadata, ctx)
def register_widget_from_metadata(
metadata: WidgetMetadata[T],
ctx: ScriptRunContext | None,
) -> RegisterWidgetResult[T]:
"""Register a widget and return its value, using an already constructed
`WidgetMetadata`.
This is split out from `register_widget` to allow caching code to replay
widgets by saving and reusing the completed metadata.
See `register_widget` for details on what this returns.
"""
if ctx is None:
# Early-out if we don't have a script run context (which probably means
# we're running as a "bare" Python script, and not via `streamlit run`).
return RegisterWidgetResult.failure(deserializer=metadata.deserializer)
widget_id = metadata.id
user_key = user_key_from_element_id(widget_id)
return ctx.session_state.register_widget(metadata, user_key)