Mise à jour de Monitor.py et autres scripts
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from streamlit.runtime.state.common import WidgetArgs, WidgetCallback, WidgetKwargs
|
||||
from streamlit.runtime.state.query_params_proxy import QueryParamsProxy
|
||||
from streamlit.runtime.state.safe_session_state import SafeSessionState
|
||||
from streamlit.runtime.state.session_state import (
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
|
||||
SessionState,
|
||||
SessionStateStatProvider,
|
||||
)
|
||||
from streamlit.runtime.state.session_state_proxy import (
|
||||
SessionStateProxy,
|
||||
get_session_state,
|
||||
)
|
||||
from streamlit.runtime.state.widgets import register_widget
|
||||
|
||||
__all__ = [
|
||||
"WidgetArgs",
|
||||
"WidgetCallback",
|
||||
"WidgetKwargs",
|
||||
"QueryParamsProxy",
|
||||
"SafeSessionState",
|
||||
"SCRIPT_RUN_WITHOUT_ERRORS_KEY",
|
||||
"SessionState",
|
||||
"SessionStateStatProvider",
|
||||
"SessionStateProxy",
|
||||
"get_session_state",
|
||||
"register_widget",
|
||||
]
|
||||
@@ -0,0 +1,191 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Functions and data structures shared by session_state.py and widgets.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Final,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
|
||||
from typing_extensions import TypeAlias, TypeGuard
|
||||
|
||||
from streamlit import util
|
||||
from streamlit.errors import (
|
||||
StreamlitAPIException,
|
||||
)
|
||||
|
||||
GENERATED_ELEMENT_ID_PREFIX: Final = "$$ID"
|
||||
TESTING_KEY = "$$STREAMLIT_INTERNAL_KEY_TESTING"
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
WidgetArgs: TypeAlias = tuple[Any, ...]
|
||||
WidgetKwargs: TypeAlias = dict[str, Any]
|
||||
WidgetCallback: TypeAlias = Callable[..., None]
|
||||
|
||||
# A deserializer receives the value from whatever field is set on the
|
||||
# WidgetState proto, and returns a regular python value. A serializer
|
||||
# receives a regular python value, and returns something suitable for
|
||||
# a value field on WidgetState proto. They should be inverses.
|
||||
WidgetDeserializer: TypeAlias = Callable[[Any, str], T]
|
||||
WidgetSerializer: TypeAlias = Callable[[T], Any]
|
||||
|
||||
# The array value field names are part of the larger set of possible value
|
||||
# field names. See the explanation for said set below. The message types
|
||||
# associated with these fields are distinguished by storing data in a `data`
|
||||
# field in their messages, meaning they need special treatment in certain
|
||||
# circumstances. Hence, they need their own, dedicated, sub-type.
|
||||
ArrayValueFieldName: TypeAlias = Literal[
|
||||
"double_array_value",
|
||||
"int_array_value",
|
||||
"string_array_value",
|
||||
]
|
||||
|
||||
# A frozenset containing the allowed values of the ArrayValueFieldName type.
|
||||
# Useful for membership checking.
|
||||
_ARRAY_VALUE_FIELD_NAMES: Final = frozenset(
|
||||
cast(
|
||||
"tuple[ArrayValueFieldName, ...]",
|
||||
# NOTE: get_args is not recursive, so this only works as long as
|
||||
# ArrayValueFieldName remains flat.
|
||||
get_args(ArrayValueFieldName),
|
||||
)
|
||||
)
|
||||
|
||||
# These are the possible field names that can be set in the `value` oneof-field
|
||||
# of the WidgetState message (schema found in .proto/WidgetStates.proto).
|
||||
# We need these as a literal type to ensure correspondence with the protobuf
|
||||
# schema in certain parts of the python code.
|
||||
# TODO(harahu): It would be preferable if this type was automatically derived
|
||||
# from the protobuf schema, rather than manually maintained. Not sure how to
|
||||
# achieve that, though.
|
||||
ValueFieldName: TypeAlias = Literal[
|
||||
ArrayValueFieldName,
|
||||
"arrow_value",
|
||||
"bool_value",
|
||||
"bytes_value",
|
||||
"double_value",
|
||||
"file_uploader_state_value",
|
||||
"int_value",
|
||||
"json_value",
|
||||
"string_value",
|
||||
"trigger_value",
|
||||
"string_trigger_value",
|
||||
"chat_input_value",
|
||||
]
|
||||
|
||||
|
||||
def is_array_value_field_name(obj: object) -> TypeGuard[ArrayValueFieldName]:
|
||||
return obj in _ARRAY_VALUE_FIELD_NAMES
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WidgetMetadata(Generic[T]):
|
||||
"""Metadata associated with a single widget. Immutable."""
|
||||
|
||||
id: str
|
||||
deserializer: WidgetDeserializer[T] = field(repr=False)
|
||||
serializer: WidgetSerializer[T] = field(repr=False)
|
||||
value_type: ValueFieldName
|
||||
|
||||
# An optional user-code callback invoked when the widget's value changes.
|
||||
# Widget callbacks are called at the start of a script run, before the
|
||||
# body of the script is executed.
|
||||
callback: WidgetCallback | None = None
|
||||
callback_args: WidgetArgs | None = None
|
||||
callback_kwargs: WidgetKwargs | None = None
|
||||
|
||||
fragment_id: str | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegisterWidgetResult(Generic[T_co]):
|
||||
"""Result returned by the `register_widget` family of functions/methods.
|
||||
|
||||
Should be usable by widget code to determine what value to return, and
|
||||
whether to update the UI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : T_co
|
||||
The widget's current value, or, in cases where the true widget value
|
||||
could not be determined, an appropriate fallback value.
|
||||
|
||||
This value should be returned by the widget call.
|
||||
value_changed : bool
|
||||
True if the widget's value is different from the value most recently
|
||||
returned from the frontend.
|
||||
|
||||
Implies an update to the frontend is needed.
|
||||
"""
|
||||
|
||||
value: T_co
|
||||
value_changed: bool
|
||||
|
||||
@classmethod
|
||||
def failure(
|
||||
cls, deserializer: WidgetDeserializer[T_co]
|
||||
) -> RegisterWidgetResult[T_co]:
|
||||
"""The canonical way to construct a RegisterWidgetResult in cases
|
||||
where the true widget value could not be determined.
|
||||
"""
|
||||
return cls(value=deserializer(None, ""), value_changed=False)
|
||||
|
||||
|
||||
def user_key_from_element_id(element_id: str) -> str | None:
|
||||
"""Return the user key portion of a element id, or None if the id does not
|
||||
have a user key.
|
||||
|
||||
TODO This will incorrectly indicate no user key if the user actually provides
|
||||
"None" as a key, but we can't avoid this kind of problem while storing the
|
||||
string representation of the no-user-key sentinel as part of the element id.
|
||||
"""
|
||||
user_key: str | None = element_id.split("-", maxsplit=2)[-1]
|
||||
return None if user_key == "None" else user_key
|
||||
|
||||
|
||||
def is_element_id(key: str) -> bool:
|
||||
"""True if the given session_state key has the structure of a element ID."""
|
||||
return key.startswith(GENERATED_ELEMENT_ID_PREFIX)
|
||||
|
||||
|
||||
def is_keyed_element_id(key: str) -> bool:
|
||||
"""True if the given session_state key has the structure of a element ID
|
||||
with a user_key.
|
||||
"""
|
||||
return is_element_id(key) and not key.endswith("-None")
|
||||
|
||||
|
||||
def require_valid_user_key(key: str) -> None:
|
||||
"""Raise an Exception if the given user_key is invalid."""
|
||||
if is_element_id(key):
|
||||
raise StreamlitAPIException(
|
||||
f"Keys beginning with {GENERATED_ELEMENT_ID_PREFIX} are reserved."
|
||||
)
|
||||
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Iterator, MutableMapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Final
|
||||
from urllib import parse
|
||||
|
||||
from streamlit.errors import StreamlitAPIException
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsKeysAndGetItem
|
||||
|
||||
|
||||
EMBED_QUERY_PARAM: Final[str] = "embed"
|
||||
EMBED_OPTIONS_QUERY_PARAM: Final[str] = "embed_options"
|
||||
EMBED_QUERY_PARAMS_KEYS: Final[list[str]] = [
|
||||
EMBED_QUERY_PARAM,
|
||||
EMBED_OPTIONS_QUERY_PARAM,
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryParams(MutableMapping[str, str]):
|
||||
"""A lightweight wrapper of a dict that sends forwardMsgs when state changes.
|
||||
It stores str keys with str and List[str] values.
|
||||
"""
|
||||
|
||||
_query_params: dict[str, list[str] | str] = field(default_factory=dict)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
self._ensure_single_query_api_used()
|
||||
|
||||
return iter(
|
||||
key
|
||||
for key in self._query_params.keys()
|
||||
if key not in EMBED_QUERY_PARAMS_KEYS
|
||||
)
|
||||
|
||||
def __getitem__(self, key: str) -> str:
|
||||
"""Retrieves a value for a given key in query parameters.
|
||||
Returns the last item in a list or an empty string if empty.
|
||||
If the key is not present, raise KeyError.
|
||||
"""
|
||||
self._ensure_single_query_api_used()
|
||||
try:
|
||||
if key in EMBED_QUERY_PARAMS_KEYS:
|
||||
raise KeyError(missing_key_error_message(key))
|
||||
value = self._query_params[key]
|
||||
if isinstance(value, list):
|
||||
if len(value) == 0:
|
||||
return ""
|
||||
else:
|
||||
# Return the last value to mimic Tornado's behavior
|
||||
# https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.get_query_argument
|
||||
return value[-1]
|
||||
return value
|
||||
except KeyError:
|
||||
raise KeyError(missing_key_error_message(key))
|
||||
|
||||
def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
|
||||
self._ensure_single_query_api_used()
|
||||
self.__set_item_internal(key, value)
|
||||
self._send_query_param_msg()
|
||||
|
||||
def __set_item_internal(self, key: str, value: str | Iterable[str]) -> None:
|
||||
if isinstance(value, dict):
|
||||
raise StreamlitAPIException(
|
||||
f"You cannot set a query params key `{key}` to a dictionary."
|
||||
)
|
||||
|
||||
if key in EMBED_QUERY_PARAMS_KEYS:
|
||||
raise StreamlitAPIException(
|
||||
"Query param embed and embed_options (case-insensitive) cannot be set programmatically."
|
||||
)
|
||||
# Type checking users should handle the string serialization themselves
|
||||
# We will accept any type for the list and serialize to str just in case
|
||||
if isinstance(value, Iterable) and not isinstance(value, str):
|
||||
self._query_params[key] = [str(item) for item in value]
|
||||
else:
|
||||
self._query_params[key] = str(value)
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
self._ensure_single_query_api_used()
|
||||
try:
|
||||
if key in EMBED_QUERY_PARAMS_KEYS:
|
||||
raise KeyError(missing_key_error_message(key))
|
||||
del self._query_params[key]
|
||||
self._send_query_param_msg()
|
||||
except KeyError:
|
||||
raise KeyError(missing_key_error_message(key))
|
||||
|
||||
def update(
|
||||
self,
|
||||
other: Iterable[tuple[str, str | Iterable[str]]]
|
||||
| SupportsKeysAndGetItem[str, str | Iterable[str]] = (),
|
||||
/,
|
||||
**kwds: str,
|
||||
):
|
||||
# This overrides the `update` provided by MutableMapping
|
||||
# to ensure only one one ForwardMsg is sent.
|
||||
self._ensure_single_query_api_used()
|
||||
if hasattr(other, "keys") and hasattr(other, "__getitem__"):
|
||||
for key in other.keys():
|
||||
self.__set_item_internal(key, other[key])
|
||||
else:
|
||||
for key, value in other:
|
||||
self.__set_item_internal(key, value)
|
||||
for key, value in kwds.items():
|
||||
self.__set_item_internal(key, value)
|
||||
self._send_query_param_msg()
|
||||
|
||||
def get_all(self, key: str) -> list[str]:
|
||||
self._ensure_single_query_api_used()
|
||||
if key not in self._query_params or key in EMBED_QUERY_PARAMS_KEYS:
|
||||
return []
|
||||
value = self._query_params[key]
|
||||
return value if isinstance(value, list) else [value]
|
||||
|
||||
def __len__(self) -> int:
|
||||
self._ensure_single_query_api_used()
|
||||
return len(
|
||||
{key for key in self._query_params if key not in EMBED_QUERY_PARAMS_KEYS}
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
self._ensure_single_query_api_used()
|
||||
return str(self._query_params)
|
||||
|
||||
def _send_query_param_msg(self) -> None:
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return
|
||||
self._ensure_single_query_api_used()
|
||||
|
||||
msg = ForwardMsg()
|
||||
msg.page_info_changed.query_string = parse.urlencode(
|
||||
self._query_params, doseq=True
|
||||
)
|
||||
ctx.query_string = msg.page_info_changed.query_string
|
||||
ctx.enqueue(msg)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._ensure_single_query_api_used()
|
||||
self.clear_with_no_forward_msg(preserve_embed=True)
|
||||
self._send_query_param_msg()
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
self._ensure_single_query_api_used()
|
||||
# return the last query param if multiple values are set
|
||||
return {
|
||||
key: self[key]
|
||||
for key in self._query_params
|
||||
if key not in EMBED_QUERY_PARAMS_KEYS
|
||||
}
|
||||
|
||||
def from_dict(
|
||||
self,
|
||||
_dict: Iterable[tuple[str, str | Iterable[str]]]
|
||||
| SupportsKeysAndGetItem[str, str | Iterable[str]],
|
||||
):
|
||||
self._ensure_single_query_api_used()
|
||||
old_value = self._query_params.copy()
|
||||
self.clear_with_no_forward_msg(preserve_embed=True)
|
||||
try:
|
||||
self.update(_dict)
|
||||
except StreamlitAPIException:
|
||||
# restore the original from before we made any changes.
|
||||
self._query_params = old_value
|
||||
raise
|
||||
|
||||
def set_with_no_forward_msg(self, key: str, val: list[str] | str) -> None:
|
||||
self._query_params[key] = val
|
||||
|
||||
def clear_with_no_forward_msg(self, preserve_embed: bool = False) -> None:
|
||||
self._query_params = {
|
||||
key: value
|
||||
for key, value in self._query_params.items()
|
||||
if key in EMBED_QUERY_PARAMS_KEYS and preserve_embed
|
||||
}
|
||||
|
||||
def _ensure_single_query_api_used(self):
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return
|
||||
ctx.mark_production_query_params_used()
|
||||
|
||||
|
||||
def missing_key_error_message(key: str) -> str:
|
||||
return f'st.query_params has no key "{key}".'
|
||||
@@ -0,0 +1,218 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Iterator, MutableMapping
|
||||
from typing import TYPE_CHECKING, overload
|
||||
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.state.session_state_proxy import get_session_state
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsKeysAndGetItem
|
||||
|
||||
|
||||
class QueryParamsProxy(MutableMapping[str, str]):
|
||||
"""
|
||||
A stateless singleton that proxies ``st.query_params`` interactions
|
||||
to the current script thread's QueryParams instance.
|
||||
"""
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
with get_session_state().query_params() as qp:
|
||||
return iter(qp)
|
||||
|
||||
def __len__(self) -> int:
|
||||
with get_session_state().query_params() as qp:
|
||||
return len(qp)
|
||||
|
||||
def __str__(self) -> str:
|
||||
with get_session_state().query_params() as qp:
|
||||
return str(qp)
|
||||
|
||||
@gather_metrics("query_params.get_item")
|
||||
def __getitem__(self, key: str) -> str:
|
||||
with get_session_state().query_params() as qp:
|
||||
try:
|
||||
return qp[key]
|
||||
except KeyError:
|
||||
raise KeyError(self.missing_key_error_message(key))
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
with get_session_state().query_params() as qp:
|
||||
del qp[key]
|
||||
|
||||
@gather_metrics("query_params.set_item")
|
||||
def __setitem__(self, key: str, value: str | Iterable[str]) -> None:
|
||||
with get_session_state().query_params() as qp:
|
||||
qp[key] = value
|
||||
|
||||
@gather_metrics("query_params.get_attr")
|
||||
def __getattr__(self, key: str) -> str:
|
||||
with get_session_state().query_params() as qp:
|
||||
try:
|
||||
return qp[key]
|
||||
except KeyError:
|
||||
raise AttributeError(self.missing_attr_error_message(key))
|
||||
|
||||
def __delattr__(self, key: str) -> None:
|
||||
with get_session_state().query_params() as qp:
|
||||
try:
|
||||
del qp[key]
|
||||
except KeyError:
|
||||
raise AttributeError(self.missing_key_error_message(key))
|
||||
|
||||
@overload
|
||||
def update(
|
||||
self, mapping: SupportsKeysAndGetItem[str, str | Iterable[str]], /, **kwds: str
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def update(
|
||||
self, keys_and_values: Iterable[tuple[str, str | Iterable[str]]], /, **kwds: str
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def update(self, **kwds: str | Iterable[str]) -> None: ...
|
||||
|
||||
def update(self, other=(), /, **kwds):
|
||||
"""
|
||||
Update one or more values in query_params at once from a dictionary or
|
||||
dictionary-like object.
|
||||
|
||||
See `Mapping.update()` from Python's `collections` library.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
other: SupportsKeysAndGetItem[str, str] | Iterable[tuple[str, str]]
|
||||
A dictionary or mapping of strings to strings.
|
||||
**kwds: str
|
||||
Additional key/value pairs to update passed as keyword arguments.
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
qp.update(other, **kwds)
|
||||
|
||||
@gather_metrics("query_params.set_attr")
|
||||
def __setattr__(self, key: str, value: str | Iterable[str]) -> None:
|
||||
with get_session_state().query_params() as qp:
|
||||
qp[key] = value
|
||||
|
||||
@gather_metrics("query_params.get_all")
|
||||
def get_all(self, key: str) -> list[str]:
|
||||
"""
|
||||
Get a list of all query parameter values associated to a given key.
|
||||
|
||||
When a key is repeated as a query parameter within the URL, this method
|
||||
allows all values to be obtained. In contrast, dict-like methods only
|
||||
retrieve the last value when a key is repeated in the URL.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key: str
|
||||
The label of the query parameter in the URL.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
A list of values associated to the given key. May return zero, one,
|
||||
or multiple values.
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
return qp.get_all(key)
|
||||
|
||||
@gather_metrics("query_params.clear")
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all query parameters from the URL of the app.
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
qp.clear()
|
||||
|
||||
@gather_metrics("query_params.to_dict")
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
"""
|
||||
Get all query parameters as a dictionary.
|
||||
|
||||
This method primarily exists for internal use and is not needed for
|
||||
most cases. ``st.query_params`` returns an object that inherits from
|
||||
``dict`` by default.
|
||||
|
||||
When a key is repeated as a query parameter within the URL, this method
|
||||
will return only the last value of each unique key.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str,str]
|
||||
A dictionary of the current query paramters in the app's URL.
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
return qp.to_dict()
|
||||
|
||||
@overload
|
||||
def from_dict(
|
||||
self, keys_and_values: Iterable[tuple[str, str | Iterable[str]]]
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def from_dict(
|
||||
self, mapping: SupportsKeysAndGetItem[str, str | Iterable[str]]
|
||||
) -> None: ...
|
||||
|
||||
@gather_metrics("query_params.from_dict")
|
||||
def from_dict(self, params):
|
||||
"""
|
||||
Set all of the query parameters from a dictionary or dictionary-like object.
|
||||
|
||||
This method primarily exists for advanced users who want to control
|
||||
multiple query parameters in a single update. To set individual query
|
||||
parameters, use key or attribute notation instead.
|
||||
|
||||
This method inherits limitations from ``st.query_params`` and can't be
|
||||
used to set embedding options as described in `Embed your app \
|
||||
<https://docs.streamlit.io/deploy/streamlit-community-cloud/share-your-app/embed-your-app#embed-options>`_.
|
||||
|
||||
To handle repeated keys, the value in a key-value pair should be a list.
|
||||
|
||||
.. note::
|
||||
``.from_dict()`` is not a direct inverse of ``.to_dict()`` if
|
||||
you are working with repeated keys. A true inverse operation is
|
||||
``{key: st.query_params.get_all(key) for key in st.query_params}``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
params: dict
|
||||
A dictionary used to replace the current query parameters.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> st.query_params.from_dict({"foo": "bar", "baz": [1, "two"]})
|
||||
|
||||
"""
|
||||
with get_session_state().query_params() as qp:
|
||||
return qp.from_dict(params)
|
||||
|
||||
@staticmethod
|
||||
def missing_key_error_message(key: str) -> str:
|
||||
return f'st.query_params has no key "{key}".'
|
||||
|
||||
@staticmethod
|
||||
def missing_attr_error_message(key: str) -> str:
|
||||
return f'st.query_params has no attribute "{key}".'
|
||||
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
|
||||
from streamlit.runtime.state.common import RegisterWidgetResult, T, WidgetMetadata
|
||||
from streamlit.runtime.state.query_params import QueryParams
|
||||
from streamlit.runtime.state.session_state import SessionState
|
||||
|
||||
|
||||
class SafeSessionState:
|
||||
"""Thread-safe wrapper around SessionState.
|
||||
|
||||
When AppSession gets a re-run request, it can interrupt its existing
|
||||
ScriptRunner and spin up a new ScriptRunner to handle the request.
|
||||
When this happens, the existing ScriptRunner will continue executing
|
||||
its script until it reaches a yield point - but during this time, it
|
||||
must not mutate its SessionState.
|
||||
"""
|
||||
|
||||
_state: SessionState
|
||||
_lock: threading.RLock
|
||||
_yield_callback: Callable[[], None]
|
||||
|
||||
def __init__(self, state: SessionState, yield_callback: Callable[[], None]):
|
||||
# Fields must be set using the object's setattr method to avoid
|
||||
# infinite recursion from trying to look up the fields we're setting.
|
||||
object.__setattr__(self, "_state", state)
|
||||
# TODO: we'd prefer this be a threading.Lock instead of RLock -
|
||||
# but `call_callbacks` first needs to be rewritten.
|
||||
object.__setattr__(self, "_lock", threading.RLock())
|
||||
object.__setattr__(self, "_yield_callback", yield_callback)
|
||||
|
||||
def register_widget(
|
||||
self, metadata: WidgetMetadata[T], user_key: str | None
|
||||
) -> RegisterWidgetResult[T]:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
return self._state.register_widget(metadata, user_key)
|
||||
|
||||
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
# TODO: rewrite this to copy the callbacks list into a local
|
||||
# variable so that we don't need to hold our lock for the
|
||||
# duration. (This will also allow us to downgrade our RLock
|
||||
# to a Lock.)
|
||||
self._state.on_script_will_rerun(latest_widget_states)
|
||||
|
||||
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
|
||||
with self._lock:
|
||||
self._state.on_script_finished(widget_ids_this_run)
|
||||
|
||||
def maybe_check_serializable(self) -> None:
|
||||
with self._lock:
|
||||
self._state.maybe_check_serializable()
|
||||
|
||||
def get_widget_states(self) -> list[WidgetStateProto]:
|
||||
"""Return a list of serialized widget values for each widget with a value."""
|
||||
with self._lock:
|
||||
return self._state.get_widget_states()
|
||||
|
||||
def is_new_state_value(self, user_key: str) -> bool:
|
||||
with self._lock:
|
||||
return self._state.is_new_state_value(user_key)
|
||||
|
||||
@property
|
||||
def filtered_state(self) -> dict[str, Any]:
|
||||
"""The combined session and widget state, excluding keyless widgets."""
|
||||
with self._lock:
|
||||
return self._state.filtered_state
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
return self._state[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
self._state[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
del self._state[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
return key in self._state
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(f"{key} not found in session_state.")
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key: str) -> None:
|
||||
try:
|
||||
del self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(f"{key} not found in session_state.")
|
||||
|
||||
def __repr__(self):
|
||||
"""Presents itself as a simple dict of the underlying SessionState instance."""
|
||||
kv = ((k, self._state[k]) for k in self._state._keys())
|
||||
s = ", ".join(f"{k}: {v!r}" for k, v in kv)
|
||||
return f"{{{s}}}"
|
||||
|
||||
@contextmanager
|
||||
def query_params(self) -> Iterator[QueryParams]:
|
||||
self._yield_callback()
|
||||
with self._lock:
|
||||
yield self._state.query_params
|
||||
@@ -0,0 +1,773 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from collections.abc import Iterator, KeysView, MutableMapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Final,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import streamlit as st
|
||||
from streamlit import config, util
|
||||
from streamlit.errors import StreamlitAPIException, UnserializableSessionStateError
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
|
||||
from streamlit.runtime.state.common import (
|
||||
RegisterWidgetResult,
|
||||
T,
|
||||
ValueFieldName,
|
||||
WidgetMetadata,
|
||||
is_array_value_field_name,
|
||||
is_element_id,
|
||||
is_keyed_element_id,
|
||||
)
|
||||
from streamlit.runtime.state.query_params import QueryParams
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider, group_stats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.session_manager import SessionManager
|
||||
|
||||
|
||||
STREAMLIT_INTERNAL_KEY_PREFIX: Final = "$$STREAMLIT_INTERNAL_KEY"
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY: Final = (
|
||||
f"{STREAMLIT_INTERNAL_KEY_PREFIX}_SCRIPT_RUN_WITHOUT_ERRORS"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Serialized:
|
||||
"""A widget value that's serialized to a protobuf. Immutable."""
|
||||
|
||||
value: WidgetStateProto
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Value:
|
||||
"""A widget value that's not serialized. Immutable."""
|
||||
|
||||
value: Any
|
||||
|
||||
|
||||
WState: TypeAlias = Union[Value, Serialized]
|
||||
|
||||
|
||||
@dataclass
|
||||
class WStates(MutableMapping[str, Any]):
|
||||
"""A mapping of widget IDs to values. Widget values can be stored in
|
||||
serialized or deserialized form, but when values are retrieved from the
|
||||
mapping, they'll always be deserialized.
|
||||
"""
|
||||
|
||||
states: dict[str, WState] = field(default_factory=dict)
|
||||
widget_metadata: dict[str, WidgetMetadata[Any]] = field(default_factory=dict)
|
||||
|
||||
def __repr__(self):
|
||||
return util.repr_(self)
|
||||
|
||||
def __getitem__(self, k: str) -> Any:
|
||||
"""Return the value of the widget with the given key.
|
||||
If the widget's value is currently stored in serialized form, it
|
||||
will be deserialized first.
|
||||
"""
|
||||
wstate = self.states.get(k)
|
||||
if wstate is None:
|
||||
raise KeyError(k)
|
||||
|
||||
if isinstance(wstate, Value):
|
||||
# The widget's value is already deserialized - return it directly.
|
||||
return wstate.value
|
||||
|
||||
# The widget's value is serialized. We deserialize it, and return
|
||||
# the deserialized value.
|
||||
|
||||
metadata = self.widget_metadata.get(k)
|
||||
if metadata is None:
|
||||
# No deserializer, which should only happen if state is
|
||||
# gotten from a reconnecting browser and the script is
|
||||
# trying to access it. Pretend it doesn't exist.
|
||||
raise KeyError(k)
|
||||
value_field_name = cast(
|
||||
"ValueFieldName",
|
||||
wstate.value.WhichOneof("value"),
|
||||
)
|
||||
value = (
|
||||
wstate.value.__getattribute__(value_field_name)
|
||||
if value_field_name # Field name is None if the widget value was cleared
|
||||
else None
|
||||
)
|
||||
|
||||
if is_array_value_field_name(value_field_name):
|
||||
# Array types are messages with data in a `data` field
|
||||
value = value.data
|
||||
elif value_field_name == "json_value":
|
||||
value = json.loads(value)
|
||||
|
||||
deserialized = metadata.deserializer(value, metadata.id)
|
||||
|
||||
# Update metadata to reflect information from WidgetState proto
|
||||
self.set_widget_metadata(
|
||||
replace(
|
||||
metadata,
|
||||
value_type=value_field_name,
|
||||
)
|
||||
)
|
||||
|
||||
self.states[k] = Value(deserialized)
|
||||
return deserialized
|
||||
|
||||
def __setitem__(self, k: str, v: WState) -> None:
|
||||
self.states[k] = v
|
||||
|
||||
def __delitem__(self, k: str) -> None:
|
||||
del self.states[k]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.states)
|
||||
|
||||
def __iter__(self):
|
||||
# For this and many other methods, we can't simply delegate to the
|
||||
# states field, because we need to invoke `__getitem__` for any
|
||||
# values, to handle deserialization and unwrapping of values.
|
||||
yield from self.states
|
||||
|
||||
def keys(self) -> KeysView[str]:
|
||||
return KeysView(self.states)
|
||||
|
||||
def items(self) -> set[tuple[str, Any]]: # type: ignore[override]
|
||||
return {(k, self[k]) for k in self}
|
||||
|
||||
def values(self) -> set[Any]: # type: ignore[override]
|
||||
return {self[wid] for wid in self}
|
||||
|
||||
def update(self, other: WStates) -> None: # type: ignore[override]
|
||||
"""Copy all widget values and metadata from 'other' into this mapping,
|
||||
overwriting any data in this mapping that's also present in 'other'.
|
||||
"""
|
||||
self.states.update(other.states)
|
||||
self.widget_metadata.update(other.widget_metadata)
|
||||
|
||||
def set_widget_from_proto(self, widget_state: WidgetStateProto) -> None:
|
||||
"""Set a widget's serialized value, overwriting any existing value it has."""
|
||||
self[widget_state.id] = Serialized(widget_state)
|
||||
|
||||
def set_from_value(self, k: str, v: Any) -> None:
|
||||
"""Set a widget's deserialized value, overwriting any existing value it has."""
|
||||
self[k] = Value(v)
|
||||
|
||||
def set_widget_metadata(self, widget_meta: WidgetMetadata[Any]) -> None:
|
||||
"""Set a widget's metadata, overwriting any existing metadata it has."""
|
||||
self.widget_metadata[widget_meta.id] = widget_meta
|
||||
|
||||
def remove_stale_widgets(
|
||||
self,
|
||||
active_widget_ids: set[str],
|
||||
fragment_ids_this_run: list[str] | None,
|
||||
) -> None:
|
||||
"""Remove widget state for stale widgets."""
|
||||
self.states = {
|
||||
k: v
|
||||
for k, v in self.states.items()
|
||||
if not _is_stale_widget(
|
||||
self.widget_metadata.get(k),
|
||||
active_widget_ids,
|
||||
fragment_ids_this_run,
|
||||
)
|
||||
}
|
||||
|
||||
def get_serialized(self, k: str) -> WidgetStateProto | None:
|
||||
"""Get the serialized value of the widget with the given id.
|
||||
|
||||
If the widget doesn't exist, return None. If the widget exists but
|
||||
is not in serialized form, it will be serialized first.
|
||||
"""
|
||||
|
||||
item = self.states.get(k)
|
||||
if item is None:
|
||||
# No such widget: return None.
|
||||
return None
|
||||
|
||||
if isinstance(item, Serialized):
|
||||
# Widget value is serialized: return it directly.
|
||||
return item.value
|
||||
|
||||
# Widget value is not serialized: serialize it first!
|
||||
metadata = self.widget_metadata.get(k)
|
||||
if metadata is None:
|
||||
# We're missing the widget's metadata. (Can this happen?)
|
||||
return None
|
||||
|
||||
widget = WidgetStateProto()
|
||||
widget.id = k
|
||||
|
||||
field = metadata.value_type
|
||||
serialized = metadata.serializer(item.value)
|
||||
|
||||
if is_array_value_field_name(field):
|
||||
arr = getattr(widget, field)
|
||||
arr.data.extend(serialized)
|
||||
elif field == "json_value":
|
||||
setattr(widget, field, json.dumps(serialized))
|
||||
elif field == "file_uploader_state_value":
|
||||
widget.file_uploader_state_value.CopyFrom(serialized)
|
||||
elif field == "string_trigger_value":
|
||||
widget.string_trigger_value.CopyFrom(serialized)
|
||||
elif field == "chat_input_value":
|
||||
widget.chat_input_value.CopyFrom(serialized)
|
||||
elif field is not None and serialized is not None:
|
||||
# If the field is None, the widget value was cleared
|
||||
# by the user and therefore is None. But we cannot
|
||||
# set it to None here, since the proto properties are
|
||||
# not nullable. So we just don't set it.
|
||||
setattr(widget, field, serialized)
|
||||
|
||||
return widget
|
||||
|
||||
def as_widget_states(self) -> list[WidgetStateProto]:
|
||||
"""Return a list of serialized widget values for each widget with a value."""
|
||||
states = [
|
||||
self.get_serialized(widget_id)
|
||||
for widget_id in self.states.keys()
|
||||
if self.get_serialized(widget_id)
|
||||
]
|
||||
states = cast("list[WidgetStateProto]", states)
|
||||
return states
|
||||
|
||||
def call_callback(self, widget_id: str) -> None:
|
||||
"""Call the given widget's callback and return the callback's
|
||||
return value. If the widget has no callback, return None.
|
||||
|
||||
If the widget doesn't exist, raise an Exception.
|
||||
"""
|
||||
metadata = self.widget_metadata.get(widget_id)
|
||||
assert metadata is not None
|
||||
callback = metadata.callback
|
||||
if callback is None:
|
||||
return
|
||||
|
||||
args = metadata.callback_args or ()
|
||||
kwargs = metadata.callback_kwargs or {}
|
||||
callback(*args, **kwargs)
|
||||
|
||||
|
||||
def _missing_key_error_message(key: str) -> str:
|
||||
return (
|
||||
f'st.session_state has no key "{key}". Did you forget to initialize it? '
|
||||
f"More info: https://docs.streamlit.io/develop/concepts/architecture/session-state#initialization"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KeyIdMapper:
|
||||
"""A mapping of user-provided keys to element IDs.
|
||||
It also maps element IDs to user-provided keys so that this reverse mapping
|
||||
does not have to be computed ad-hoc.
|
||||
All built-in dict-operations such as setting and deleting expect the key as the
|
||||
argument, not the element ID.
|
||||
"""
|
||||
|
||||
_key_id_mapping: dict[str, str] = field(default_factory=dict)
|
||||
_id_key_mapping: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self._key_id_mapping
|
||||
|
||||
def __setitem__(self, key: str, widget_id: Any) -> None:
|
||||
self._key_id_mapping[key] = widget_id
|
||||
self._id_key_mapping[widget_id] = key
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
self.delete(key)
|
||||
|
||||
@property
|
||||
def id_key_mapping(self) -> dict[str, str]:
|
||||
return self._id_key_mapping
|
||||
|
||||
def set_key_id_mapping(self, key_id_mapping: dict[str, str]) -> None:
|
||||
self._key_id_mapping = key_id_mapping
|
||||
self._id_key_mapping = {v: k for k, v in key_id_mapping.items()}
|
||||
|
||||
def get_id_from_key(self, key: str, default: Any = None) -> str:
|
||||
return self._key_id_mapping.get(key, default)
|
||||
|
||||
def get_key_from_id(self, widget_id: str) -> str:
|
||||
return self._id_key_mapping[widget_id]
|
||||
|
||||
def update(self, other: KeyIdMapper) -> None:
|
||||
self._key_id_mapping.update(other._key_id_mapping)
|
||||
self._id_key_mapping.update(other._id_key_mapping)
|
||||
|
||||
def clear(self):
|
||||
self._key_id_mapping.clear()
|
||||
self._id_key_mapping.clear()
|
||||
|
||||
def delete(self, key: str):
|
||||
widget_id = self._key_id_mapping[key]
|
||||
del self._key_id_mapping[key]
|
||||
del self._id_key_mapping[widget_id]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
"""SessionState allows users to store values that persist between app
|
||||
reruns.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> if "num_script_runs" not in st.session_state:
|
||||
... st.session_state.num_script_runs = 0
|
||||
>>> st.session_state.num_script_runs += 1
|
||||
>>> st.write(st.session_state.num_script_runs) # writes 1
|
||||
|
||||
The next time your script runs, the value of
|
||||
st.session_state.num_script_runs will be preserved.
|
||||
>>> st.session_state.num_script_runs += 1
|
||||
>>> st.write(st.session_state.num_script_runs) # writes 2
|
||||
"""
|
||||
|
||||
# All the values from previous script runs, squished together to save memory
|
||||
_old_state: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Values set in session state during the current script run, possibly for
|
||||
# setting a widget's value. Keyed by a user provided string.
|
||||
_new_session_state: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Widget values from the frontend, usually one changing prompted the script rerun
|
||||
_new_widget_state: WStates = field(default_factory=WStates)
|
||||
|
||||
# Keys used for widgets will be eagerly converted to the matching element id
|
||||
_key_id_mapper: KeyIdMapper = field(default_factory=KeyIdMapper)
|
||||
|
||||
# query params are stored in session state because query params will be tied with
|
||||
# widget state at one point.
|
||||
query_params: QueryParams = field(default_factory=QueryParams)
|
||||
|
||||
def __repr__(self):
|
||||
return util.repr_(self)
|
||||
|
||||
# is it possible for a value to get through this without being deserialized?
|
||||
def _compact_state(self) -> None:
|
||||
"""Copy all current session_state and widget_state values into our
|
||||
_old_state dict, and then clear our current session_state and
|
||||
widget_state.
|
||||
"""
|
||||
for key_or_wid in self:
|
||||
try:
|
||||
self._old_state[key_or_wid] = self[key_or_wid]
|
||||
except KeyError:
|
||||
# handle key errors from widget state not having metadata gracefully
|
||||
# https://github.com/streamlit/streamlit/issues/7206
|
||||
pass
|
||||
self._new_session_state.clear()
|
||||
self._new_widget_state.clear()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Reset self completely, clearing all current and old values."""
|
||||
self._old_state.clear()
|
||||
self._new_session_state.clear()
|
||||
self._new_widget_state.clear()
|
||||
self._key_id_mapper.clear()
|
||||
|
||||
@property
|
||||
def filtered_state(self) -> dict[str, Any]:
|
||||
"""The combined session and widget state, excluding keyless widgets."""
|
||||
|
||||
wid_key_map = self._key_id_mapper.id_key_mapping
|
||||
|
||||
state: dict[str, Any] = {}
|
||||
|
||||
# We can't write `for k, v in self.items()` here because doing so will
|
||||
# run into a `KeyError` if widget metadata has been cleared (which
|
||||
# happens when the streamlit server restarted or the cache was cleared),
|
||||
# then we receive a widget's state from a browser.
|
||||
for k in self._keys():
|
||||
if not is_element_id(k) and not _is_internal_key(k):
|
||||
state[k] = self[k]
|
||||
elif is_keyed_element_id(k):
|
||||
try:
|
||||
key = wid_key_map[k]
|
||||
state[key] = self[k]
|
||||
except KeyError:
|
||||
# Widget id no longer maps to a key, it is a not yet
|
||||
# cleared value in old state for a reset widget
|
||||
pass
|
||||
|
||||
return state
|
||||
|
||||
def _keys(self) -> set[str]:
|
||||
"""All keys active in Session State, with widget keys converted
|
||||
to widget ids when one is known. (This includes autogenerated keys
|
||||
for widgets that don't have user_keys defined, and which aren't
|
||||
exposed to user code).
|
||||
"""
|
||||
old_keys = {self._get_widget_id(k) for k in self._old_state.keys()}
|
||||
new_widget_keys = set(self._new_widget_state.keys())
|
||||
new_session_state_keys = {
|
||||
self._get_widget_id(k) for k in self._new_session_state.keys()
|
||||
}
|
||||
return old_keys | new_widget_keys | new_session_state_keys
|
||||
|
||||
def is_new_state_value(self, user_key: str) -> bool:
|
||||
"""True if a value with the given key is in the current session state."""
|
||||
return user_key in self._new_session_state
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
"""Return an iterator over the keys of the SessionState.
|
||||
This is a shortcut for `iter(self.keys())`.
|
||||
"""
|
||||
return iter(self._keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of items in SessionState."""
|
||||
return len(self._keys())
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
wid_key_map = self._key_id_mapper.id_key_mapping
|
||||
widget_id = self._get_widget_id(key)
|
||||
|
||||
if widget_id in wid_key_map and widget_id == key:
|
||||
# the "key" is a raw widget id, so get its associated user key for lookup
|
||||
key = wid_key_map[widget_id]
|
||||
try:
|
||||
return self._getitem(widget_id, key)
|
||||
except KeyError:
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
def _getitem(self, widget_id: str | None, user_key: str | None) -> Any:
|
||||
"""Get the value of an entry in Session State, using either the
|
||||
user-provided key or a widget id as appropriate for the internal dict
|
||||
being accessed.
|
||||
|
||||
At least one of the arguments must have a value.
|
||||
"""
|
||||
assert user_key is not None or widget_id is not None
|
||||
|
||||
if user_key is not None:
|
||||
try:
|
||||
return self._new_session_state[user_key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if widget_id is not None:
|
||||
try:
|
||||
return self._new_widget_state[widget_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Typically, there won't be both a widget id and an associated state key in
|
||||
# old state at the same time, so the order we check is arbitrary.
|
||||
# The exception is if session state is set and then a later run has
|
||||
# a widget created, so the widget id entry should be newer.
|
||||
# The opposite case shouldn't happen, because setting the value of a widget
|
||||
# through session state will result in the next widget state reflecting that
|
||||
# value.
|
||||
if widget_id is not None:
|
||||
try:
|
||||
return self._old_state[widget_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if user_key is not None:
|
||||
try:
|
||||
return self._old_state[user_key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# We'll never get here
|
||||
raise KeyError
|
||||
|
||||
def __setitem__(self, user_key: str, value: Any) -> None:
|
||||
"""Set the value of the session_state entry with the given user_key.
|
||||
|
||||
If the key corresponds to a widget or form that's been instantiated
|
||||
during the current script run, raise a StreamlitAPIException instead.
|
||||
"""
|
||||
ctx = get_script_run_ctx()
|
||||
|
||||
if ctx is not None:
|
||||
widget_id = self._key_id_mapper.get_id_from_key(user_key, None)
|
||||
widget_ids = ctx.widget_ids_this_run
|
||||
form_ids = ctx.form_ids_this_run
|
||||
|
||||
if widget_id in widget_ids or user_key in form_ids:
|
||||
raise StreamlitAPIException(
|
||||
f"`st.session_state.{user_key}` cannot be modified after the widget"
|
||||
f" with key `{user_key}` is instantiated."
|
||||
)
|
||||
|
||||
self._new_session_state[user_key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
widget_id = self._get_widget_id(key)
|
||||
|
||||
if not (key in self or widget_id in self):
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
if key in self._new_session_state:
|
||||
del self._new_session_state[key]
|
||||
|
||||
if key in self._old_state:
|
||||
del self._old_state[key]
|
||||
|
||||
if key in self._key_id_mapper:
|
||||
self._key_id_mapper.delete(key)
|
||||
|
||||
if widget_id in self._new_widget_state:
|
||||
del self._new_widget_state[widget_id]
|
||||
|
||||
if widget_id in self._old_state:
|
||||
del self._old_state[widget_id]
|
||||
|
||||
def set_widgets_from_proto(self, widget_states: WidgetStatesProto) -> None:
|
||||
"""Set the value of all widgets represented in the given WidgetStatesProto."""
|
||||
for state in widget_states.widgets:
|
||||
self._new_widget_state.set_widget_from_proto(state)
|
||||
|
||||
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
|
||||
"""Called by ScriptRunner before its script re-runs.
|
||||
|
||||
Update widget data and call callbacks on widgets whose value changed
|
||||
between the previous and current script runs.
|
||||
"""
|
||||
# Clear any triggers that weren't reset because the script was disconnected
|
||||
self._reset_triggers()
|
||||
self._compact_state()
|
||||
self.set_widgets_from_proto(latest_widget_states)
|
||||
self._call_callbacks()
|
||||
|
||||
def _call_callbacks(self) -> None:
|
||||
"""Call any callback associated with each widget whose value
|
||||
changed between the previous and current script runs.
|
||||
"""
|
||||
from streamlit.runtime.scriptrunner import RerunException
|
||||
|
||||
changed_widget_ids = [
|
||||
wid for wid in self._new_widget_state if self._widget_changed(wid)
|
||||
]
|
||||
for wid in changed_widget_ids:
|
||||
try:
|
||||
self._new_widget_state.call_callback(wid)
|
||||
except RerunException:
|
||||
st.warning("Calling st.rerun() within a callback is a no-op.")
|
||||
|
||||
def _widget_changed(self, widget_id: str) -> bool:
|
||||
"""True if the given widget's value changed between the previous
|
||||
script run and the current script run.
|
||||
"""
|
||||
new_value = self._new_widget_state.get(widget_id)
|
||||
old_value = self._old_state.get(widget_id)
|
||||
changed: bool = new_value != old_value
|
||||
return changed
|
||||
|
||||
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
|
||||
"""Called by ScriptRunner after its script finishes running.
|
||||
Updates widgets to prepare for the next script run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
widget_ids_this_run: set[str]
|
||||
The IDs of the widgets that were accessed during the script
|
||||
run. Any widget state whose ID does *not* appear in this set
|
||||
is considered "stale" and will be removed.
|
||||
"""
|
||||
self._reset_triggers()
|
||||
self._remove_stale_widgets(widget_ids_this_run)
|
||||
|
||||
def _reset_triggers(self) -> None:
|
||||
"""Set all trigger values in our state dictionary to False."""
|
||||
for state_id in self._new_widget_state:
|
||||
metadata = self._new_widget_state.widget_metadata.get(state_id)
|
||||
if metadata is not None:
|
||||
if metadata.value_type == "trigger_value":
|
||||
self._new_widget_state[state_id] = Value(False)
|
||||
elif metadata.value_type == "string_trigger_value":
|
||||
self._new_widget_state[state_id] = Value(None)
|
||||
elif metadata.value_type == "chat_input_value":
|
||||
self._new_widget_state[state_id] = Value(None)
|
||||
|
||||
for state_id in self._old_state:
|
||||
metadata = self._new_widget_state.widget_metadata.get(state_id)
|
||||
if metadata is not None:
|
||||
if metadata.value_type == "trigger_value":
|
||||
self._old_state[state_id] = False
|
||||
elif metadata.value_type == "string_trigger_value":
|
||||
self._old_state[state_id] = None
|
||||
elif metadata.value_type == "chat_input_value":
|
||||
self._old_state[state_id] = None
|
||||
|
||||
def _remove_stale_widgets(self, active_widget_ids: set[str]) -> None:
|
||||
"""Remove widget state for widgets whose ids aren't in `active_widget_ids`."""
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return
|
||||
|
||||
self._new_widget_state.remove_stale_widgets(
|
||||
active_widget_ids,
|
||||
ctx.fragment_ids_this_run,
|
||||
)
|
||||
|
||||
# Remove entries from _old_state corresponding to
|
||||
# widgets not in widget_ids.
|
||||
self._old_state = {
|
||||
k: v
|
||||
for k, v in self._old_state.items()
|
||||
if (
|
||||
not is_element_id(k)
|
||||
or not _is_stale_widget(
|
||||
self._new_widget_state.widget_metadata.get(k),
|
||||
active_widget_ids,
|
||||
ctx.fragment_ids_this_run,
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
def _set_widget_metadata(self, widget_metadata: WidgetMetadata[Any]) -> None:
|
||||
"""Set a widget's metadata."""
|
||||
widget_id = widget_metadata.id
|
||||
self._new_widget_state.widget_metadata[widget_id] = widget_metadata
|
||||
|
||||
def get_widget_states(self) -> list[WidgetStateProto]:
|
||||
"""Return a list of serialized widget values for each widget with a value."""
|
||||
return self._new_widget_state.as_widget_states()
|
||||
|
||||
def _get_widget_id(self, k: str) -> str:
|
||||
"""Turns a value that might be a widget id or a user provided key into
|
||||
an appropriate widget id.
|
||||
"""
|
||||
return self._key_id_mapper.get_id_from_key(k, k)
|
||||
|
||||
def _set_key_widget_mapping(self, widget_id: str, user_key: str) -> None:
|
||||
self._key_id_mapper[user_key] = widget_id
|
||||
|
||||
def register_widget(
|
||||
self, metadata: WidgetMetadata[T], user_key: str | None
|
||||
) -> RegisterWidgetResult[T]:
|
||||
"""Register a widget with the SessionState.
|
||||
|
||||
Returns
|
||||
-------
|
||||
RegisterWidgetResult[T]
|
||||
Contains the widget's current value, and a bool that will be True
|
||||
if the frontend needs to be updated with the current value.
|
||||
"""
|
||||
widget_id = metadata.id
|
||||
|
||||
self._set_widget_metadata(metadata)
|
||||
if user_key is not None:
|
||||
# If the widget has a user_key, update its user_key:widget_id mapping
|
||||
self._set_key_widget_mapping(widget_id, user_key)
|
||||
|
||||
if widget_id not in self and (user_key is None or user_key not in self):
|
||||
# This is the first time the widget is registered, so we save its
|
||||
# value in widget state.
|
||||
deserializer = metadata.deserializer
|
||||
initial_widget_value = deepcopy(deserializer(None, metadata.id))
|
||||
self._new_widget_state.set_from_value(widget_id, initial_widget_value)
|
||||
|
||||
# Get the current value of the widget for use as its return value.
|
||||
# We return a copy, so that reference types can't be accidentally
|
||||
# mutated by user code.
|
||||
widget_value = cast("T", self[widget_id])
|
||||
widget_value = deepcopy(widget_value)
|
||||
|
||||
# widget_value_changed indicates to the caller that the widget's
|
||||
# current value is different from what is in the frontend.
|
||||
widget_value_changed = user_key is not None and self.is_new_state_value(
|
||||
user_key
|
||||
)
|
||||
|
||||
return RegisterWidgetResult(widget_value, widget_value_changed)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
try:
|
||||
self[key]
|
||||
except KeyError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
# Lazy-load vendored package to prevent import of numpy
|
||||
from streamlit.vendor.pympler.asizeof import asizeof
|
||||
|
||||
stat = CacheStat("st_session_state", "", asizeof(self))
|
||||
return [stat]
|
||||
|
||||
def _check_serializable(self) -> None:
|
||||
"""Verify that everything added to session state can be serialized.
|
||||
We use pickleability as the metric for serializability, and test for
|
||||
pickleability by just trying it.
|
||||
"""
|
||||
for k in self:
|
||||
try:
|
||||
pickle.dumps(self[k])
|
||||
except Exception as e:
|
||||
err_msg = f"""Cannot serialize the value (of type `{type(self[k])}`) of '{k}' in st.session_state.
|
||||
Streamlit has been configured to use [pickle](https://docs.python.org/3/library/pickle.html) to
|
||||
serialize session_state values. Please convert the value to a pickle-serializable type. To learn
|
||||
more about this behavior, see [our docs](https://docs.streamlit.io/knowledge-base/using-streamlit/serializable-session-state). """
|
||||
raise UnserializableSessionStateError(err_msg) from e
|
||||
|
||||
def maybe_check_serializable(self) -> None:
|
||||
"""Verify that session state can be serialized, if the relevant config
|
||||
option is set.
|
||||
|
||||
See `_check_serializable` for details.
|
||||
"""
|
||||
if config.get_option("runner.enforceSerializableSessionState"):
|
||||
self._check_serializable()
|
||||
|
||||
|
||||
def _is_internal_key(key: str) -> bool:
|
||||
return key.startswith(STREAMLIT_INTERNAL_KEY_PREFIX)
|
||||
|
||||
|
||||
def _is_stale_widget(
|
||||
metadata: WidgetMetadata[Any] | None,
|
||||
active_widget_ids: set[str],
|
||||
fragment_ids_this_run: list[str] | None,
|
||||
) -> bool:
|
||||
if not metadata:
|
||||
return True
|
||||
elif metadata.id in active_widget_ids:
|
||||
return False
|
||||
# If we're running 1 or more fragments, but this widget is unrelated to any of the
|
||||
# fragments that we're running, then it should not be marked as stale as its value
|
||||
# may still be needed for a future fragment run or full script run.
|
||||
elif fragment_ids_this_run and metadata.fragment_id not in fragment_ids_this_run:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStateStatProvider(CacheStatsProvider):
|
||||
_session_mgr: SessionManager
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
stats: list[CacheStat] = []
|
||||
for session_info in self._session_mgr.list_active_sessions():
|
||||
session_state = session_info.session.session_state
|
||||
stats.extend(session_state.get_stats())
|
||||
return group_stats(stats)
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, MutableMapping
|
||||
from typing import Any, Final
|
||||
|
||||
from streamlit import logger as _logger
|
||||
from streamlit import runtime
|
||||
from streamlit.elements.lib.utils import Key
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.state.common import require_valid_user_key
|
||||
from streamlit.runtime.state.safe_session_state import SafeSessionState
|
||||
from streamlit.runtime.state.session_state import SessionState
|
||||
|
||||
_LOGGER: Final = _logger.get_logger(__name__)
|
||||
|
||||
|
||||
_state_use_warning_already_displayed: bool = False
|
||||
# The mock session state is used as a fallback if the script is run without `streamlit run`
|
||||
_mock_session_state: SafeSessionState | None = None
|
||||
|
||||
|
||||
def get_session_state() -> SafeSessionState:
|
||||
"""Get the SessionState object for the current session.
|
||||
|
||||
Note that in streamlit scripts, this function should not be called
|
||||
directly. Instead, SessionState objects should be accessed via
|
||||
st.session_state.
|
||||
"""
|
||||
global _state_use_warning_already_displayed
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import (
|
||||
get_script_run_ctx,
|
||||
)
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
|
||||
# If there is no script run context because the script is run bare, we
|
||||
# use a global mock session state version to allow bare script execution (via python script.py)
|
||||
if ctx is None:
|
||||
if not _state_use_warning_already_displayed:
|
||||
_state_use_warning_already_displayed = True
|
||||
if not runtime.exists():
|
||||
_LOGGER.warning(
|
||||
"Session state does not function when running a script without `streamlit run`"
|
||||
)
|
||||
|
||||
global _mock_session_state
|
||||
|
||||
if _mock_session_state is None:
|
||||
# Lazy initialize the mock session state
|
||||
_mock_session_state = SafeSessionState(SessionState(), lambda: None)
|
||||
return _mock_session_state
|
||||
return ctx.session_state
|
||||
|
||||
|
||||
class SessionStateProxy(MutableMapping[Key, Any]):
|
||||
"""A stateless singleton that proxies `st.session_state` interactions
|
||||
to the current script thread's SessionState instance.
|
||||
|
||||
The proxy API differs slightly from SessionState: it does not allow
|
||||
callers to get, set, or iterate over "keyless" widgets (that is, widgets
|
||||
that were created without a user_key, and have autogenerated keys).
|
||||
"""
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
"""Iterator over user state and keyed widget values."""
|
||||
# TODO: this is unsafe if fastReruns is true! Let's deprecate/remove.
|
||||
return iter(get_session_state().filtered_state)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Number of user state and keyed widget values in session_state."""
|
||||
return len(get_session_state().filtered_state)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of user state and keyed widget values."""
|
||||
return str(get_session_state().filtered_state)
|
||||
|
||||
def __getitem__(self, key: Key) -> Any:
|
||||
"""Return the state or widget value with the given key.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
If the key is not a valid SessionState user key.
|
||||
"""
|
||||
key = str(key)
|
||||
require_valid_user_key(key)
|
||||
return get_session_state()[key]
|
||||
|
||||
@gather_metrics("session_state.set_item")
|
||||
def __setitem__(self, key: Key, value: Any) -> None:
|
||||
"""Set the value of the given key.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
If the key is not a valid SessionState user key.
|
||||
"""
|
||||
key = str(key)
|
||||
require_valid_user_key(key)
|
||||
get_session_state()[key] = value
|
||||
|
||||
def __delitem__(self, key: Key) -> None:
|
||||
"""Delete the value with the given key.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
If the key is not a valid SessionState user key.
|
||||
"""
|
||||
key = str(key)
|
||||
require_valid_user_key(key)
|
||||
del get_session_state()[key]
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(_missing_attr_error_message(key))
|
||||
|
||||
@gather_metrics("session_state.set_attr")
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key: str) -> None:
|
||||
try:
|
||||
del self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(_missing_attr_error_message(key))
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Return a dict containing all session_state and keyed widget values."""
|
||||
return get_session_state().filtered_state
|
||||
|
||||
|
||||
def _missing_attr_error_message(attr_name: str) -> str:
|
||||
return (
|
||||
f'st.session_state has no attribute "{attr_name}". Did you forget to initialize it? '
|
||||
f"More info: https://docs.streamlit.io/develop/concepts/architecture/session-state#initialization"
|
||||
)
|
||||
@@ -0,0 +1,135 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from streamlit.runtime.state.common import (
|
||||
RegisterWidgetResult,
|
||||
T,
|
||||
ValueFieldName,
|
||||
WidgetArgs,
|
||||
WidgetCallback,
|
||||
WidgetDeserializer,
|
||||
WidgetKwargs,
|
||||
WidgetMetadata,
|
||||
WidgetSerializer,
|
||||
user_key_from_element_id,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.scriptrunner import ScriptRunContext
|
||||
|
||||
|
||||
def register_widget(
|
||||
element_id: str,
|
||||
*,
|
||||
deserializer: WidgetDeserializer[T],
|
||||
serializer: WidgetSerializer[T],
|
||||
ctx: ScriptRunContext | None,
|
||||
on_change_handler: WidgetCallback | None = None,
|
||||
args: WidgetArgs | None = None,
|
||||
kwargs: WidgetKwargs | None = None,
|
||||
value_type: ValueFieldName,
|
||||
) -> RegisterWidgetResult[T]:
|
||||
"""Register a widget with Streamlit, and return its current value.
|
||||
NOTE: This function should be called after the proto has been filled.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
element_id : str
|
||||
The id of the element. Must be unique.
|
||||
deserializer : WidgetDeserializer[T]
|
||||
Called to convert a widget's protobuf value to the value returned by
|
||||
its st.<widget_name> function.
|
||||
serializer : WidgetSerializer[T]
|
||||
Called to convert a widget's value to its protobuf representation.
|
||||
ctx : ScriptRunContext or None
|
||||
Used to ensure uniqueness of widget IDs, and to look up widget values.
|
||||
on_change_handler : WidgetCallback or None
|
||||
An optional callback invoked when the widget's value changes.
|
||||
args : WidgetArgs or None
|
||||
args to pass to on_change_handler when invoked
|
||||
kwargs : WidgetKwargs or None
|
||||
kwargs to pass to on_change_handler when invoked
|
||||
value_type: ValueType
|
||||
The value_type the widget is going to use.
|
||||
We use this information to start with a best-effort guess for the value_type
|
||||
of each widget. Once we actually receive a proto for a widget from the
|
||||
frontend, the guess is updated to be the correct type. Unfortunately, we're
|
||||
not able to always rely on the proto as the type may be needed earlier.
|
||||
Thankfully, in these cases (when value_type == "trigger_value"), the static
|
||||
table here being slightly inaccurate should never pose a problem.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
register_widget_result : RegisterWidgetResult[T]
|
||||
Provides information on which value to return to the widget caller,
|
||||
and whether the UI needs updating.
|
||||
|
||||
- Unhappy path:
|
||||
- Our ScriptRunContext doesn't exist (meaning that we're running
|
||||
as a "bare script" outside streamlit).
|
||||
- We are disconnected from the SessionState instance.
|
||||
In both cases we'll return a fallback RegisterWidgetResult[T].
|
||||
- Happy path:
|
||||
- The widget has already been registered on a previous run but the
|
||||
user hasn't interacted with it on the client. The widget will have
|
||||
the default value it was first created with. We then return a
|
||||
RegisterWidgetResult[T], containing this value.
|
||||
- The widget has already been registered and the user *has*
|
||||
interacted with it. The widget will have that most recent
|
||||
user-specified value. We then return a RegisterWidgetResult[T],
|
||||
containing this value.
|
||||
|
||||
For both paths a widget return value is provided, allowing the widgets
|
||||
to be used in a non-streamlit setting.
|
||||
"""
|
||||
# Create the widget's updated metadata, and register it with session_state.
|
||||
metadata = WidgetMetadata(
|
||||
element_id,
|
||||
deserializer,
|
||||
serializer,
|
||||
value_type=value_type,
|
||||
callback=on_change_handler,
|
||||
callback_args=args,
|
||||
callback_kwargs=kwargs,
|
||||
fragment_id=ctx.current_fragment_id if ctx else None,
|
||||
)
|
||||
return register_widget_from_metadata(metadata, ctx)
|
||||
|
||||
|
||||
def register_widget_from_metadata(
|
||||
metadata: WidgetMetadata[T],
|
||||
ctx: ScriptRunContext | None,
|
||||
) -> RegisterWidgetResult[T]:
|
||||
"""Register a widget and return its value, using an already constructed
|
||||
`WidgetMetadata`.
|
||||
|
||||
This is split out from `register_widget` to allow caching code to replay
|
||||
widgets by saving and reusing the completed metadata.
|
||||
|
||||
See `register_widget` for details on what this returns.
|
||||
"""
|
||||
if ctx is None:
|
||||
# Early-out if we don't have a script run context (which probably means
|
||||
# we're running as a "bare" Python script, and not via `streamlit run`).
|
||||
return RegisterWidgetResult.failure(deserializer=metadata.deserializer)
|
||||
|
||||
widget_id = metadata.id
|
||||
user_key = user_key_from_element_id(widget_id)
|
||||
|
||||
return ctx.session_state.register_widget(metadata, user_key)
|
||||
Reference in New Issue
Block a user