Mise à jour de Monitor.py et autres scripts
This commit is contained in:
13
myenv/lib/python3.11/site-packages/streamlit/web/__init__.py
Normal file
13
myenv/lib/python3.11/site-packages/streamlit/web/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
355
myenv/lib/python3.11/site-packages/streamlit/web/bootstrap.py
Normal file
355
myenv/lib/python3.11/site-packages/streamlit/web/bootstrap.py
Normal file
@@ -0,0 +1,355 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import Any, Final
|
||||
|
||||
from streamlit import cli_util, config, env_util, file_util, net_util, secrets
|
||||
from streamlit.config import CONFIG_FILENAMES
|
||||
from streamlit.git_util import MIN_GIT_VERSION, GitRepo
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.watcher import report_watchdog_availability, watch_file
|
||||
from streamlit.web.server import Server, server_address_is_unix_socket, server_util
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
# The maximum possible total size of a static directory.
|
||||
# We agreed on these limitations for the initial release of static file sharing,
|
||||
# based on security concerns from the SiS and Community Cloud teams
|
||||
MAX_APP_STATIC_FOLDER_SIZE = 1 * 1024 * 1024 * 1024 # 1 GB
|
||||
|
||||
|
||||
def _set_up_signal_handler(server: Server) -> None:
|
||||
_LOGGER.debug("Setting up signal handler")
|
||||
|
||||
def signal_handler(signal_number, stack_frame):
|
||||
# The server will shut down its threads and exit its loop.
|
||||
server.stop()
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
if sys.platform == "win32":
|
||||
signal.signal(signal.SIGBREAK, signal_handler)
|
||||
else:
|
||||
signal.signal(signal.SIGQUIT, signal_handler)
|
||||
|
||||
|
||||
def _fix_sys_path(main_script_path: str) -> None:
|
||||
"""Add the script's folder to the sys path.
|
||||
|
||||
Python normally does this automatically, but since we exec the script
|
||||
ourselves we need to do it instead.
|
||||
"""
|
||||
sys.path.insert(0, os.path.dirname(main_script_path))
|
||||
|
||||
|
||||
def _fix_tornado_crash() -> None:
|
||||
"""Set default asyncio policy to be compatible with Tornado 6.
|
||||
|
||||
Tornado 6 (at least) is not compatible with the default
|
||||
asyncio implementation on Windows. So here we
|
||||
pick the older SelectorEventLoopPolicy when the OS is Windows
|
||||
if the known-incompatible default policy is in use.
|
||||
|
||||
This has to happen as early as possible to make it a low priority and
|
||||
overridable
|
||||
|
||||
See: https://github.com/tornadoweb/tornado/issues/2608
|
||||
|
||||
FIXME: if/when tornado supports the defaults in asyncio,
|
||||
remove and bump tornado requirement for py38
|
||||
"""
|
||||
if env_util.IS_WINDOWS:
|
||||
try:
|
||||
from asyncio import ( # type: ignore[attr-defined]
|
||||
WindowsProactorEventLoopPolicy,
|
||||
WindowsSelectorEventLoopPolicy,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
# Not affected
|
||||
else:
|
||||
if type(asyncio.get_event_loop_policy()) is WindowsProactorEventLoopPolicy:
|
||||
# WindowsProactorEventLoopPolicy is not compatible with
|
||||
# Tornado 6 fallback to the pre-3.8 default of Selector
|
||||
asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
def _fix_sys_argv(main_script_path: str, args: list[str]) -> None:
|
||||
"""sys.argv needs to exclude streamlit arguments and parameters
|
||||
and be set to what a user's script may expect.
|
||||
"""
|
||||
import sys
|
||||
|
||||
sys.argv = [main_script_path] + list(args)
|
||||
|
||||
|
||||
def _on_server_start(server: Server) -> None:
|
||||
_maybe_print_old_git_warning(server.main_script_path)
|
||||
_maybe_print_static_folder_warning(server.main_script_path)
|
||||
_print_url(server.is_running_hello)
|
||||
report_watchdog_availability()
|
||||
|
||||
# Load secrets.toml if it exists. If the file doesn't exist, this
|
||||
# function will return without raising an exception. We catch any parse
|
||||
# errors and display them here.
|
||||
try:
|
||||
secrets.load_if_toml_exists()
|
||||
except Exception as ex:
|
||||
_LOGGER.error("Failed to load secrets.toml file", exc_info=ex)
|
||||
|
||||
def maybe_open_browser():
|
||||
if config.get_option("server.headless"):
|
||||
# Don't open browser when in headless mode.
|
||||
return
|
||||
|
||||
if config.is_manually_set("browser.serverAddress"):
|
||||
addr = config.get_option("browser.serverAddress")
|
||||
elif config.is_manually_set("server.address"):
|
||||
if server_address_is_unix_socket():
|
||||
# Don't open browser when server address is an unix socket
|
||||
return
|
||||
addr = config.get_option("server.address")
|
||||
else:
|
||||
addr = "localhost"
|
||||
|
||||
cli_util.open_browser(server_util.get_url(addr))
|
||||
|
||||
# Schedule the browser to open on the main thread.
|
||||
asyncio.get_running_loop().call_soon(maybe_open_browser)
|
||||
|
||||
|
||||
def _fix_pydeck_mapbox_api_warning() -> None:
|
||||
"""Sets MAPBOX_API_KEY environment variable needed for PyDeck otherwise it
|
||||
will throw an exception.
|
||||
"""
|
||||
|
||||
os.environ["MAPBOX_API_KEY"] = config.get_option("mapbox.token")
|
||||
|
||||
|
||||
def _maybe_print_static_folder_warning(main_script_path: str) -> None:
|
||||
"""Prints a warning if the static folder is misconfigured."""
|
||||
|
||||
if config.get_option("server.enableStaticServing"):
|
||||
static_folder_path = file_util.get_app_static_dir(main_script_path)
|
||||
if not os.path.isdir(static_folder_path):
|
||||
cli_util.print_to_cli(
|
||||
f"WARNING: Static file serving is enabled, but no static folder found "
|
||||
f"at {static_folder_path}. To disable static file serving, "
|
||||
f"set server.enableStaticServing to false.",
|
||||
fg="yellow",
|
||||
)
|
||||
else:
|
||||
# Raise warning when static folder size is larger than 1 GB
|
||||
static_folder_size = file_util.get_directory_size(static_folder_path)
|
||||
|
||||
if static_folder_size > MAX_APP_STATIC_FOLDER_SIZE:
|
||||
config.set_option("server.enableStaticServing", False)
|
||||
cli_util.print_to_cli(
|
||||
"WARNING: Static folder size is larger than 1GB. "
|
||||
"Static file serving has been disabled.",
|
||||
fg="yellow",
|
||||
)
|
||||
|
||||
|
||||
def _print_url(is_running_hello: bool) -> None:
|
||||
if is_running_hello:
|
||||
title_message = "Welcome to Streamlit. Check out our demo in your browser."
|
||||
else:
|
||||
title_message = "You can now view your Streamlit app in your browser."
|
||||
|
||||
named_urls = []
|
||||
|
||||
if config.is_manually_set("browser.serverAddress"):
|
||||
named_urls = [
|
||||
("URL", server_util.get_url(config.get_option("browser.serverAddress")))
|
||||
]
|
||||
|
||||
elif (
|
||||
config.is_manually_set("server.address") and not server_address_is_unix_socket()
|
||||
):
|
||||
named_urls = [
|
||||
("URL", server_util.get_url(config.get_option("server.address"))),
|
||||
]
|
||||
|
||||
elif server_address_is_unix_socket():
|
||||
named_urls = [
|
||||
("Unix Socket", config.get_option("server.address")),
|
||||
]
|
||||
|
||||
else:
|
||||
named_urls = [
|
||||
("Local URL", server_util.get_url("localhost")),
|
||||
]
|
||||
|
||||
internal_ip = net_util.get_internal_ip()
|
||||
if internal_ip:
|
||||
named_urls.append(("Network URL", server_util.get_url(internal_ip)))
|
||||
|
||||
if config.get_option("server.headless"):
|
||||
external_ip = net_util.get_external_ip()
|
||||
if external_ip:
|
||||
named_urls.append(("External URL", server_util.get_url(external_ip)))
|
||||
|
||||
cli_util.print_to_cli("")
|
||||
cli_util.print_to_cli(" %s" % title_message, fg="blue", bold=True)
|
||||
cli_util.print_to_cli("")
|
||||
|
||||
for url_name, url in named_urls:
|
||||
cli_util.print_to_cli(f" {url_name}: ", nl=False, fg="blue")
|
||||
cli_util.print_to_cli(url, bold=True)
|
||||
|
||||
cli_util.print_to_cli("")
|
||||
|
||||
if is_running_hello:
|
||||
cli_util.print_to_cli(" Ready to create your own Python apps super quickly?")
|
||||
cli_util.print_to_cli(" Head over to ", nl=False)
|
||||
cli_util.print_to_cli("https://docs.streamlit.io", bold=True)
|
||||
cli_util.print_to_cli("")
|
||||
cli_util.print_to_cli(" May you create awesome apps!")
|
||||
cli_util.print_to_cli("")
|
||||
cli_util.print_to_cli("")
|
||||
|
||||
|
||||
def _maybe_print_old_git_warning(main_script_path: str) -> None:
|
||||
"""If our script is running in a Git repo, and we're running a very old
|
||||
Git version, print a warning that Git integration will be unavailable.
|
||||
"""
|
||||
repo = GitRepo(main_script_path)
|
||||
if (
|
||||
not repo.is_valid()
|
||||
and repo.git_version is not None
|
||||
and repo.git_version < MIN_GIT_VERSION
|
||||
):
|
||||
git_version_string = ".".join(str(val) for val in repo.git_version)
|
||||
min_version_string = ".".join(str(val) for val in MIN_GIT_VERSION)
|
||||
cli_util.print_to_cli("")
|
||||
cli_util.print_to_cli(" Git integration is disabled.", fg="yellow", bold=True)
|
||||
cli_util.print_to_cli("")
|
||||
cli_util.print_to_cli(
|
||||
f" Streamlit requires Git {min_version_string} or later, "
|
||||
f"but you have {git_version_string}.",
|
||||
fg="yellow",
|
||||
)
|
||||
cli_util.print_to_cli(
|
||||
" Git is used by Streamlit Cloud (https://streamlit.io/cloud).",
|
||||
fg="yellow",
|
||||
)
|
||||
cli_util.print_to_cli(
|
||||
" To enable this feature, please update Git.", fg="yellow"
|
||||
)
|
||||
|
||||
|
||||
def load_config_options(flag_options: dict[str, Any]) -> None:
|
||||
"""Load config options from config.toml files, then overlay the ones set by
|
||||
flag_options.
|
||||
|
||||
The "streamlit run" command supports passing Streamlit's config options
|
||||
as flags. This function reads through the config options set via flag,
|
||||
massages them, and passes them to get_config_options() so that they
|
||||
overwrite config option defaults and those loaded from config.toml files.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
flag_options : dict[str, Any]
|
||||
A dict of config options where the keys are the CLI flag version of the
|
||||
config option names.
|
||||
"""
|
||||
# We want to filter out two things: values that are None, and values that
|
||||
# are empty tuples. The latter is a special case that indicates that the
|
||||
# no values were provided, and the config should reset to the default
|
||||
options_from_flags = {
|
||||
name.replace("_", "."): val
|
||||
for name, val in flag_options.items()
|
||||
if val is not None and val != ()
|
||||
}
|
||||
|
||||
# Force a reparse of config files (if they exist). The result is cached
|
||||
# for future calls.
|
||||
config.get_config_options(force_reparse=True, options_from_flags=options_from_flags)
|
||||
|
||||
|
||||
def _install_config_watchers(flag_options: dict[str, Any]) -> None:
|
||||
def on_config_changed(_path):
|
||||
load_config_options(flag_options)
|
||||
|
||||
for filename in CONFIG_FILENAMES:
|
||||
if os.path.exists(filename):
|
||||
watch_file(filename, on_config_changed)
|
||||
|
||||
|
||||
def run(
|
||||
main_script_path: str,
|
||||
is_hello: bool,
|
||||
args: list[str],
|
||||
flag_options: dict[str, Any],
|
||||
*,
|
||||
stop_immediately_for_testing: bool = False,
|
||||
) -> None:
|
||||
"""Run a script in a separate thread and start a server for the app.
|
||||
|
||||
This starts a blocking asyncio eventloop.
|
||||
"""
|
||||
_fix_sys_path(main_script_path)
|
||||
_fix_tornado_crash()
|
||||
_fix_sys_argv(main_script_path, args)
|
||||
_fix_pydeck_mapbox_api_warning()
|
||||
_install_config_watchers(flag_options)
|
||||
|
||||
# Create the server. It won't start running yet.
|
||||
server = Server(main_script_path, is_hello)
|
||||
|
||||
async def run_server() -> None:
|
||||
# Start the server
|
||||
await server.start()
|
||||
_on_server_start(server)
|
||||
|
||||
# Install a signal handler that will shut down the server
|
||||
# and close all our threads
|
||||
_set_up_signal_handler(server)
|
||||
|
||||
# return immediately if we're testing the server start
|
||||
if stop_immediately_for_testing:
|
||||
_LOGGER.debug("Stopping server immediately for testing")
|
||||
server.stop()
|
||||
|
||||
# Wait until `Server.stop` is called, either by our signal handler, or
|
||||
# by a debug websocket session.
|
||||
await server.stopped
|
||||
|
||||
# Run the server. This function will not return until the server is shut down.
|
||||
# FIX RuntimeError: asyncio.run() cannot be called from a running event loop on Python 3.10.16
|
||||
# asyncio.run(run_server())
|
||||
|
||||
# Define a main function to handle the event loop logic
|
||||
async def main():
|
||||
await run_server()
|
||||
|
||||
try:
|
||||
# Check if we're already in an event loop
|
||||
if asyncio.get_running_loop().is_running():
|
||||
# Use `asyncio.create_task` if we're in an async context
|
||||
asyncio.create_task(main())
|
||||
else:
|
||||
# Otherwise, use `asyncio.run`
|
||||
asyncio.run(main())
|
||||
except RuntimeError:
|
||||
# get_running_loop throws RuntimeError if no running event loop
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from streamlit.runtime.caching.storage.local_disk_cache_storage import (
|
||||
LocalDiskCacheStorageManager,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.caching.storage import CacheStorageManager
|
||||
|
||||
|
||||
def create_default_cache_storage_manager() -> CacheStorageManager:
|
||||
"""
|
||||
Get the cache storage manager.
|
||||
It would be used both in server.py and in cli.py to have unified cache storage.
|
||||
|
||||
Returns
|
||||
-------
|
||||
CacheStorageManager
|
||||
The cache storage manager.
|
||||
|
||||
"""
|
||||
return LocalDiskCacheStorageManager()
|
||||
411
myenv/lib/python3.11/site-packages/streamlit/web/cli.py
Normal file
411
myenv/lib/python3.11/site-packages/streamlit/web/cli.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A script which is run when the Streamlit package is executed."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
||||
|
||||
# We cannot lazy-load click here because its used via decorators.
|
||||
import click
|
||||
|
||||
import streamlit.runtime.caching as caching
|
||||
import streamlit.web.bootstrap as bootstrap
|
||||
from streamlit import config as _config
|
||||
from streamlit.runtime.credentials import Credentials, check_credentials
|
||||
from streamlit.web.cache_storage_manager_config import (
|
||||
create_default_cache_storage_manager,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.config_option import ConfigOption
|
||||
|
||||
ACCEPTED_FILE_EXTENSIONS = ("py", "py3")
|
||||
|
||||
LOG_LEVELS = ("error", "warning", "info", "debug")
|
||||
|
||||
|
||||
def _convert_config_option_to_click_option(
|
||||
config_option: ConfigOption,
|
||||
) -> dict[str, Any]:
|
||||
"""Composes given config option options as options for click lib."""
|
||||
option = f"--{config_option.key}"
|
||||
param = config_option.key.replace(".", "_")
|
||||
description = config_option.description
|
||||
if config_option.deprecated:
|
||||
if description is None:
|
||||
description = ""
|
||||
description += (
|
||||
f"\n {config_option.deprecation_text} - {config_option.expiration_date}"
|
||||
)
|
||||
|
||||
return {
|
||||
"param": param,
|
||||
"description": description,
|
||||
"type": config_option.type,
|
||||
"option": option,
|
||||
"envvar": config_option.env_var,
|
||||
"multiple": config_option.multiple,
|
||||
}
|
||||
|
||||
|
||||
def _make_sensitive_option_callback(config_option: ConfigOption):
|
||||
def callback(_ctx: click.Context, _param: click.Parameter, cli_value) -> None:
|
||||
if cli_value is None:
|
||||
return None
|
||||
raise SystemExit(
|
||||
f"Setting {config_option.key!r} option using the CLI flag is not allowed. "
|
||||
f"Set this option in the configuration file or environment "
|
||||
f"variable: {config_option.env_var!r}"
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def configurator_options(func: F) -> F:
|
||||
"""Decorator that adds config param keys to click dynamically."""
|
||||
for _, value in reversed(_config._config_options_template.items()):
|
||||
parsed_parameter = _convert_config_option_to_click_option(value)
|
||||
if value.sensitive:
|
||||
# Display a warning if the user tries to set sensitive
|
||||
# options using the CLI and exit with non-zero code.
|
||||
click_option_kwargs = {
|
||||
"expose_value": False,
|
||||
"hidden": True,
|
||||
"is_eager": True,
|
||||
"callback": _make_sensitive_option_callback(value),
|
||||
}
|
||||
else:
|
||||
click_option_kwargs = {
|
||||
"show_envvar": True,
|
||||
"envvar": parsed_parameter["envvar"],
|
||||
}
|
||||
config_option = click.option(
|
||||
parsed_parameter["option"],
|
||||
parsed_parameter["param"],
|
||||
help=parsed_parameter["description"],
|
||||
type=parsed_parameter["type"],
|
||||
multiple=parsed_parameter["multiple"],
|
||||
**click_option_kwargs,
|
||||
)
|
||||
func = config_option(func)
|
||||
return func
|
||||
|
||||
|
||||
def _download_remote(main_script_path: str, url_path: str) -> None:
|
||||
"""Fetch remote file at url_path to main_script_path."""
|
||||
import requests
|
||||
|
||||
with open(main_script_path, "wb") as fp:
|
||||
try:
|
||||
resp = requests.get(url_path)
|
||||
resp.raise_for_status()
|
||||
fp.write(resp.content)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise click.BadParameter(f"Unable to fetch {url_path}.\n{e}")
|
||||
|
||||
|
||||
@click.group(context_settings={"auto_envvar_prefix": "STREAMLIT"})
|
||||
@click.option("--log_level", show_default=True, type=click.Choice(LOG_LEVELS))
|
||||
@click.version_option(prog_name="Streamlit")
|
||||
def main(log_level="info"):
|
||||
"""Try out a demo with:
|
||||
|
||||
$ streamlit hello
|
||||
|
||||
Or use the line below to run your own script:
|
||||
|
||||
$ streamlit run your_script.py
|
||||
""" # noqa: D400
|
||||
|
||||
if log_level:
|
||||
from streamlit.logger import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
LOGGER.warning(
|
||||
"Setting the log level using the --log_level flag is unsupported."
|
||||
"\nUse the --logger.level flag (after your streamlit command) instead."
|
||||
)
|
||||
|
||||
|
||||
@main.command("help")
|
||||
def help():
|
||||
"""Print this help message."""
|
||||
# We use _get_command_line_as_string to run some error checks but don't do
|
||||
# anything with its return value.
|
||||
_get_command_line_as_string()
|
||||
|
||||
assert len(sys.argv) == 2 # This is always true, but let's assert anyway.
|
||||
|
||||
# Pretend user typed 'streamlit --help' instead of 'streamlit help'.
|
||||
sys.argv[1] = "--help"
|
||||
main(prog_name="streamlit")
|
||||
|
||||
|
||||
@main.command("version")
|
||||
def main_version():
|
||||
"""Print Streamlit's version number."""
|
||||
# Pretend user typed 'streamlit --version' instead of 'streamlit version'
|
||||
import sys
|
||||
|
||||
# We use _get_command_line_as_string to run some error checks but don't do
|
||||
# anything with its return value.
|
||||
_get_command_line_as_string()
|
||||
|
||||
assert len(sys.argv) == 2 # This is always true, but let's assert anyway.
|
||||
sys.argv[1] = "--version"
|
||||
main()
|
||||
|
||||
|
||||
@main.command("docs")
|
||||
def main_docs():
|
||||
"""Show help in browser."""
|
||||
click.echo("Showing help page in browser...")
|
||||
from streamlit import cli_util
|
||||
|
||||
cli_util.open_browser("https://docs.streamlit.io")
|
||||
|
||||
|
||||
@main.command("hello")
|
||||
@configurator_options
|
||||
def main_hello(**kwargs):
|
||||
"""Runs the Hello World script."""
|
||||
from streamlit.hello import streamlit_app
|
||||
|
||||
bootstrap.load_config_options(flag_options=kwargs)
|
||||
filename = streamlit_app.__file__
|
||||
_main_run(filename, flag_options=kwargs)
|
||||
|
||||
|
||||
@main.command("run")
|
||||
@configurator_options
|
||||
@click.argument("target", required=True, envvar="STREAMLIT_RUN_TARGET")
|
||||
@click.argument("args", nargs=-1)
|
||||
def main_run(target: str, args=None, **kwargs):
|
||||
"""Run a Python script, piping stderr to Streamlit.
|
||||
|
||||
The script can be local or it can be an url. In the latter case, Streamlit
|
||||
will download the script to a temporary file and runs this file.
|
||||
|
||||
"""
|
||||
from streamlit import url_util
|
||||
|
||||
bootstrap.load_config_options(flag_options=kwargs)
|
||||
|
||||
_, extension = os.path.splitext(target)
|
||||
if extension[1:] not in ACCEPTED_FILE_EXTENSIONS:
|
||||
if extension[1:] == "":
|
||||
raise click.BadArgumentUsage(
|
||||
"Streamlit requires raw Python (.py) files, but the provided file has no extension.\nFor more information, please see https://docs.streamlit.io"
|
||||
)
|
||||
else:
|
||||
raise click.BadArgumentUsage(
|
||||
f"Streamlit requires raw Python (.py) files, not {extension}.\nFor more information, please see https://docs.streamlit.io"
|
||||
)
|
||||
|
||||
if url_util.is_url(target):
|
||||
from streamlit.temporary_directory import TemporaryDirectory
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
path = urlparse(target).path
|
||||
main_script_path = os.path.join(
|
||||
temp_dir, path.strip("/").rsplit("/", 1)[-1]
|
||||
)
|
||||
# if this is a GitHub/Gist blob url, convert to a raw URL first.
|
||||
target = url_util.process_gitblob_url(target)
|
||||
_download_remote(main_script_path, target)
|
||||
_main_run(main_script_path, args, flag_options=kwargs)
|
||||
else:
|
||||
if not os.path.exists(target):
|
||||
raise click.BadParameter(f"File does not exist: {target}")
|
||||
_main_run(target, args, flag_options=kwargs)
|
||||
|
||||
|
||||
def _get_command_line_as_string() -> str | None:
|
||||
import subprocess
|
||||
|
||||
parent = click.get_current_context().parent
|
||||
if parent is None:
|
||||
return None
|
||||
|
||||
if "streamlit.cli" in parent.command_path:
|
||||
raise RuntimeError(
|
||||
"Running streamlit via `python -m streamlit.cli <command>` is"
|
||||
" unsupported. Please use `python -m streamlit <command>` instead."
|
||||
)
|
||||
|
||||
cmd_line_as_list = [parent.command_path]
|
||||
cmd_line_as_list.extend(sys.argv[1:])
|
||||
return subprocess.list2cmdline(cmd_line_as_list)
|
||||
|
||||
|
||||
def _main_run(
|
||||
file,
|
||||
args: list[str] | None = None,
|
||||
flag_options: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
if args is None:
|
||||
args = []
|
||||
|
||||
if flag_options is None:
|
||||
flag_options = {}
|
||||
|
||||
is_hello = _get_command_line_as_string() == "streamlit hello"
|
||||
|
||||
check_credentials()
|
||||
|
||||
bootstrap.run(file, is_hello, args, flag_options)
|
||||
|
||||
|
||||
# SUBCOMMAND: cache
|
||||
|
||||
|
||||
@main.group("cache")
|
||||
def cache():
|
||||
"""Manage the Streamlit cache."""
|
||||
pass
|
||||
|
||||
|
||||
@cache.command("clear")
|
||||
def cache_clear():
|
||||
"""Clear st.cache_data and st.cache_resource caches."""
|
||||
|
||||
# in this `streamlit cache clear` cli command we cannot use the
|
||||
# `cache_storage_manager from runtime (since runtime is not initialized)
|
||||
# so we create a new cache_storage_manager instance that used in runtime,
|
||||
# and call clear_all() method for it.
|
||||
# This will not remove the in-memory cache.
|
||||
cache_storage_manager = create_default_cache_storage_manager()
|
||||
cache_storage_manager.clear_all()
|
||||
caching.cache_resource.clear()
|
||||
|
||||
|
||||
# SUBCOMMAND: config
|
||||
|
||||
|
||||
@main.group("config")
|
||||
def config():
|
||||
"""Manage Streamlit's config settings."""
|
||||
pass
|
||||
|
||||
|
||||
@config.command("show")
|
||||
@configurator_options
|
||||
def config_show(**kwargs):
|
||||
"""Show all of Streamlit's config settings."""
|
||||
|
||||
bootstrap.load_config_options(flag_options=kwargs)
|
||||
|
||||
_config.show_config()
|
||||
|
||||
|
||||
# SUBCOMMAND: activate
|
||||
|
||||
|
||||
@main.group("activate", invoke_without_command=True)
|
||||
@click.pass_context
|
||||
def activate(ctx):
|
||||
"""Activate Streamlit by entering your email."""
|
||||
if not ctx.invoked_subcommand:
|
||||
Credentials.get_current().activate()
|
||||
|
||||
|
||||
@activate.command("reset")
|
||||
def activate_reset():
|
||||
"""Reset Activation Credentials."""
|
||||
Credentials.get_current().reset()
|
||||
|
||||
|
||||
# SUBCOMMAND: test
|
||||
|
||||
|
||||
@main.group("test", hidden=True)
|
||||
def test():
|
||||
"""Internal-only commands used for testing.
|
||||
|
||||
These commands are not included in the output of `streamlit help`.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@test.command("prog_name")
|
||||
def test_prog_name():
|
||||
"""Assert that the program name is set to `streamlit test`.
|
||||
|
||||
This is used by our cli-smoke-tests to verify that the program name is set
|
||||
to `streamlit ...` whether the streamlit binary is invoked directly or via
|
||||
`python -m streamlit ...`.
|
||||
"""
|
||||
# We use _get_command_line_as_string to run some error checks but don't do
|
||||
# anything with its return value.
|
||||
_get_command_line_as_string()
|
||||
|
||||
parent = click.get_current_context().parent
|
||||
|
||||
assert parent is not None
|
||||
assert parent.command_path == "streamlit test"
|
||||
|
||||
|
||||
@main.command("init")
|
||||
@click.argument("directory", required=False)
|
||||
def main_init(directory: str | None = None):
|
||||
"""Initialize a new Streamlit project.
|
||||
|
||||
If DIRECTORY is specified, create it and initialize the project there.
|
||||
Otherwise use the current directory.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
project_dir = Path(directory) if directory else Path.cwd()
|
||||
|
||||
try:
|
||||
project_dir.mkdir(exist_ok=True, parents=True)
|
||||
except OSError as e:
|
||||
raise click.ClickException(f"Failed to create directory: {e}")
|
||||
|
||||
# Create requirements.txt
|
||||
(project_dir / "requirements.txt").write_text("streamlit\n")
|
||||
|
||||
# Create streamlit_app.py
|
||||
(project_dir / "streamlit_app.py").write_text("""import streamlit as st
|
||||
|
||||
st.title("🎈 My new app")
|
||||
st.write(
|
||||
"Let's start building! For help and inspiration, head over to [docs.streamlit.io](https://docs.streamlit.io/)."
|
||||
)
|
||||
""")
|
||||
|
||||
rel_path_str = str(directory) if directory else "."
|
||||
|
||||
click.secho("✨ Created new Streamlit app in ", nl=False)
|
||||
click.secho(f"{rel_path_str}", fg="blue")
|
||||
click.echo("🚀 Run it with: ", nl=False)
|
||||
click.secho(f"streamlit run {rel_path_str}/streamlit_app.py", fg="blue")
|
||||
|
||||
if click.confirm("❓ Run the app now?", default=True):
|
||||
app_path = project_dir / "streamlit_app.py"
|
||||
click.echo("\nStarting Streamlit...")
|
||||
_main_run(str(app_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,26 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from streamlit.web.server.component_request_handler import ComponentRequestHandler
|
||||
from streamlit.web.server.routes import allow_cross_origin_requests
|
||||
from streamlit.web.server.server import Server, server_address_is_unix_socket
|
||||
from streamlit.web.server.stats_request_handler import StatsRequestHandler
|
||||
|
||||
__all__ = [
|
||||
"ComponentRequestHandler",
|
||||
"allow_cross_origin_requests",
|
||||
"Server",
|
||||
"server_address_is_unix_socket",
|
||||
"StatsRequestHandler",
|
||||
]
|
||||
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import tornado.web
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
# We agreed on these limitations for the initial release of static file sharing,
|
||||
# based on security concerns from the SiS and Community Cloud teams
|
||||
# The maximum possible size of single serving static file.
|
||||
MAX_APP_STATIC_FILE_SIZE = 200 * 1024 * 1024 # 200 MB
|
||||
# The list of file extensions that we serve with the corresponding Content-Type header.
|
||||
# All files with other extensions will be served with Content-Type: text/plain
|
||||
SAFE_APP_STATIC_FILE_EXTENSIONS = (
|
||||
# Common image types:
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".png",
|
||||
".gif",
|
||||
".webp",
|
||||
# Common font types:
|
||||
".otf",
|
||||
".ttf",
|
||||
".woff",
|
||||
".woff2",
|
||||
# Other types:
|
||||
".pdf",
|
||||
".xml",
|
||||
".json",
|
||||
)
|
||||
|
||||
|
||||
class AppStaticFileHandler(tornado.web.StaticFileHandler):
|
||||
def initialize(self, path: str, default_filename: str | None = None) -> None:
|
||||
super().initialize(path, default_filename)
|
||||
|
||||
def validate_absolute_path(self, root: str, absolute_path: str) -> str | None:
|
||||
full_path = os.path.abspath(absolute_path)
|
||||
|
||||
ret_val = super().validate_absolute_path(root, absolute_path)
|
||||
|
||||
if os.path.isdir(full_path):
|
||||
# we don't want to serve directories, and serve only files
|
||||
raise tornado.web.HTTPError(404)
|
||||
|
||||
if os.path.commonpath([full_path, root]) != root:
|
||||
# Don't allow misbehaving clients to break out of the static files directory
|
||||
_LOGGER.warning(
|
||||
"Serving files outside of the static directory is not supported"
|
||||
)
|
||||
raise tornado.web.HTTPError(404)
|
||||
|
||||
if (
|
||||
os.path.exists(full_path)
|
||||
and os.path.getsize(full_path) > MAX_APP_STATIC_FILE_SIZE
|
||||
):
|
||||
raise tornado.web.HTTPError(
|
||||
404,
|
||||
"File is too large, its size should not exceed "
|
||||
f"{MAX_APP_STATIC_FILE_SIZE} bytes",
|
||||
reason="File is too large",
|
||||
)
|
||||
|
||||
return ret_val
|
||||
|
||||
def set_default_headers(self):
|
||||
# CORS protection is disabled because we need access to this endpoint
|
||||
# from the inner iframe.
|
||||
self.set_header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
def set_extra_headers(self, path: str) -> None:
|
||||
if Path(path).suffix not in SAFE_APP_STATIC_FILE_EXTENSIONS:
|
||||
self.set_header("Content-Type", "text/plain")
|
||||
self.set_header("X-Content-Type-Options", "nosniff")
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from authlib.integrations.base_client import ( # type: ignore[import-untyped]
|
||||
FrameworkIntegration,
|
||||
)
|
||||
|
||||
from streamlit.runtime.secrets import AttrDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from streamlit.web.server.oidc_mixin import TornadoOAuth
|
||||
|
||||
|
||||
class TornadoIntegration(FrameworkIntegration): # type: ignore[misc]
|
||||
def update_token(self, token, refresh_token=None, access_token=None):
|
||||
"""We do not support access token refresh, since we obtain and operate only on
|
||||
identity tokens. We override this method explicitly to implement all abstract
|
||||
methods of base class.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def load_config(
|
||||
oauth: TornadoOAuth, name: str, params: Sequence[str]
|
||||
) -> dict[str, Any]:
|
||||
"""Configure Authlib integration with provider parameters
|
||||
specified in secrets.toml.
|
||||
"""
|
||||
|
||||
# oauth.config here is an auth section from secrets.toml
|
||||
# We parse it here to transform nested AttrDict (for client_kwargs value)
|
||||
# to dict so Authlib can work with it under the hood.
|
||||
if not oauth.config:
|
||||
return {}
|
||||
|
||||
prepared_config = {}
|
||||
for key in params:
|
||||
value = oauth.config.get(name, {}).get(key, None)
|
||||
if isinstance(value, AttrDict):
|
||||
# We want to modify client_kwargs further after loading server metadata
|
||||
value = value.to_dict()
|
||||
if value is not None:
|
||||
prepared_config[key] = value
|
||||
return prepared_config
|
||||
@@ -0,0 +1,247 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import tornado.concurrent
|
||||
import tornado.locks
|
||||
import tornado.netutil
|
||||
import tornado.web
|
||||
import tornado.websocket
|
||||
from tornado.escape import utf8
|
||||
from tornado.websocket import WebSocketHandler
|
||||
|
||||
from streamlit import config
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.BackMsg_pb2 import BackMsg
|
||||
from streamlit.runtime import Runtime, SessionClient, SessionClientDisconnectedError
|
||||
from streamlit.runtime.runtime_util import serialize_forward_msg
|
||||
from streamlit.web.server.server_util import (
|
||||
AUTH_COOKIE_NAME,
|
||||
is_url_from_allowed_origins,
|
||||
is_xsrf_enabled,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
class BrowserWebSocketHandler(WebSocketHandler, SessionClient):
|
||||
"""Handles a WebSocket connection from the browser."""
|
||||
|
||||
def initialize(self, runtime: Runtime) -> None:
|
||||
self._runtime = runtime
|
||||
self._session_id: str | None = None
|
||||
# The XSRF cookie is normally set when xsrf_form_html is used, but in a
|
||||
# pure-Javascript application that does not use any regular forms we just
|
||||
# need to read the self.xsrf_token manually to set the cookie as a side
|
||||
# effect. See https://www.tornadoweb.org/en/stable/guide/security.html#cross-site-request-forgery-protection
|
||||
# for more details.
|
||||
if is_xsrf_enabled():
|
||||
_ = self.xsrf_token
|
||||
|
||||
def get_signed_cookie(
|
||||
self,
|
||||
name: str,
|
||||
value: str | None = None,
|
||||
max_age_days: float = 31,
|
||||
min_version: int | None = None,
|
||||
) -> bytes | None:
|
||||
"""Get a signed cookie from the request. Added for compatibility with
|
||||
Tornado < 6.3.0.
|
||||
|
||||
See release notes: https://www.tornadoweb.org/en/stable/releases/v6.3.0.html#deprecation-notices
|
||||
"""
|
||||
try:
|
||||
return super().get_signed_cookie(name, value, max_age_days, min_version)
|
||||
except AttributeError:
|
||||
return super().get_secure_cookie(name, value, max_age_days, min_version)
|
||||
|
||||
def check_origin(self, origin: str) -> bool:
|
||||
"""Set up CORS."""
|
||||
return super().check_origin(origin) or is_url_from_allowed_origins(origin)
|
||||
|
||||
def _validate_xsrf_token(self, supplied_token: str) -> bool:
|
||||
"""Inspired by tornado.web.RequestHandler.check_xsrf_cookie method,
|
||||
to check the XSRF token passed in Websocket connection header.
|
||||
"""
|
||||
_, token, _ = self._decode_xsrf_token(supplied_token)
|
||||
_, expected_token, _ = self._get_raw_xsrf_token()
|
||||
|
||||
decoded_token = utf8(token)
|
||||
decoded_expected_token = utf8(expected_token)
|
||||
|
||||
if not decoded_token or not decoded_expected_token:
|
||||
return False
|
||||
return hmac.compare_digest(decoded_token, decoded_expected_token)
|
||||
|
||||
def _parse_user_cookie(self, raw_cookie_value: bytes) -> dict[str, Any]:
|
||||
"""Process the user cookie and extract the user info after
|
||||
validating the origin. Origin is validated for security reasons.
|
||||
"""
|
||||
cookie_value = json.loads(raw_cookie_value)
|
||||
user_info = {}
|
||||
|
||||
cookie_value_origin = cookie_value.get("origin", None)
|
||||
parsed_origin_from_header = urlparse(self.request.headers["Origin"])
|
||||
expected_origin_value = (
|
||||
parsed_origin_from_header.scheme + "://" + parsed_origin_from_header.netloc
|
||||
)
|
||||
if cookie_value_origin == expected_origin_value:
|
||||
user_info["is_logged_in"] = cookie_value.get("is_logged_in", False)
|
||||
del cookie_value["origin"]
|
||||
del cookie_value["is_logged_in"]
|
||||
user_info.update(cookie_value)
|
||||
|
||||
return user_info
|
||||
|
||||
def write_forward_msg(self, msg: ForwardMsg) -> None:
|
||||
"""Send a ForwardMsg to the browser."""
|
||||
try:
|
||||
self.write_message(serialize_forward_msg(msg), binary=True)
|
||||
except tornado.websocket.WebSocketClosedError as e:
|
||||
raise SessionClientDisconnectedError from e
|
||||
|
||||
def select_subprotocol(self, subprotocols: list[str]) -> str | None:
|
||||
"""Return the first subprotocol in the given list.
|
||||
|
||||
This method is used by Tornado to select a protocol when the
|
||||
Sec-WebSocket-Protocol header is set in an HTTP Upgrade request.
|
||||
|
||||
NOTE: We repurpose the Sec-WebSocket-Protocol header here in a slightly
|
||||
unfortunate (but necessary) way. The browser WebSocket API doesn't allow us to
|
||||
set arbitrary HTTP headers, and this header is the only one where we have the
|
||||
ability to set it to arbitrary values, so we use it to pass tokens (in this
|
||||
case, the previous session ID to allow us to reconnect to it) from client to
|
||||
server as the *third* value in the list.
|
||||
|
||||
The reason why the auth token is set as the third value is that:
|
||||
- when Sec-WebSocket-Protocol is set, many clients expect the server to
|
||||
respond with a selected subprotocol to use. We don't want that reply to be
|
||||
the session token, so we by convention have the client always set the first
|
||||
protocol to "streamlit" and select that.
|
||||
- the second protocol in the list is reserved in some deployment environments
|
||||
for an auth token that we currently don't use
|
||||
"""
|
||||
if subprotocols:
|
||||
return subprotocols[0]
|
||||
|
||||
return None
|
||||
|
||||
def open(self, *args, **kwargs) -> Awaitable[None] | None:
|
||||
user_info: dict[str, str | bool | None] = {}
|
||||
|
||||
existing_session_id = None
|
||||
try:
|
||||
ws_protocols = [
|
||||
p.strip()
|
||||
for p in self.request.headers["Sec-Websocket-Protocol"].split(",")
|
||||
]
|
||||
|
||||
raw_cookie_value = self.get_signed_cookie(AUTH_COOKIE_NAME)
|
||||
if is_xsrf_enabled() and raw_cookie_value:
|
||||
csrf_protocol_value = ws_protocols[1]
|
||||
|
||||
if self._validate_xsrf_token(csrf_protocol_value):
|
||||
user_info.update(self._parse_user_cookie(raw_cookie_value))
|
||||
|
||||
if len(ws_protocols) >= 3:
|
||||
# See the NOTE in the docstring of the `select_subprotocol` method above
|
||||
# for a detailed explanation of why this is done.
|
||||
existing_session_id = ws_protocols[2]
|
||||
except KeyError:
|
||||
# Just let existing_session_id=None if we run into any error while trying to
|
||||
# extract it from the Sec-Websocket-Protocol header.
|
||||
pass
|
||||
|
||||
self._session_id = self._runtime.connect_session(
|
||||
client=self,
|
||||
user_info=user_info,
|
||||
existing_session_id=existing_session_id,
|
||||
)
|
||||
return None
|
||||
|
||||
def on_close(self) -> None:
|
||||
if not self._session_id:
|
||||
return
|
||||
self._runtime.disconnect_session(self._session_id)
|
||||
self._session_id = None
|
||||
|
||||
def get_compression_options(self) -> dict[Any, Any] | None:
|
||||
"""Enable WebSocket compression.
|
||||
|
||||
Returning an empty dict enables websocket compression. Returning
|
||||
None disables it.
|
||||
|
||||
(See the docstring in the parent class.)
|
||||
"""
|
||||
if config.get_option("server.enableWebsocketCompression"):
|
||||
return {}
|
||||
return None
|
||||
|
||||
def on_message(self, payload: str | bytes) -> None:
|
||||
if not self._session_id:
|
||||
return
|
||||
|
||||
try:
|
||||
if isinstance(payload, str):
|
||||
# Sanity check. (The frontend should only be sending us bytes;
|
||||
# Protobuf.ParseFromString does not accept str input.)
|
||||
raise RuntimeError(
|
||||
"WebSocket received an unexpected `str` message. "
|
||||
"(We expect `bytes` only.)"
|
||||
)
|
||||
|
||||
msg = BackMsg()
|
||||
msg.ParseFromString(payload)
|
||||
_LOGGER.debug("Received the following back message:\n%s", msg)
|
||||
|
||||
except Exception as ex:
|
||||
_LOGGER.exception("Error deserializing back message")
|
||||
self._runtime.handle_backmsg_deserialization_exception(self._session_id, ex)
|
||||
return
|
||||
|
||||
# "debug_disconnect_websocket" and "debug_shutdown_runtime" are special
|
||||
# developmentMode-only messages used in e2e tests to test reconnect handling and
|
||||
# disabling widgets.
|
||||
if msg.WhichOneof("type") == "debug_disconnect_websocket":
|
||||
if config.get_option("global.developmentMode") or config.get_option(
|
||||
"global.e2eTest"
|
||||
):
|
||||
self.close()
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"Client tried to disconnect websocket when not in development mode or e2e testing."
|
||||
)
|
||||
elif msg.WhichOneof("type") == "debug_shutdown_runtime":
|
||||
if config.get_option("global.developmentMode") or config.get_option(
|
||||
"global.e2eTest"
|
||||
):
|
||||
self._runtime.stop()
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"Client tried to shut down runtime when not in development mode or e2e testing."
|
||||
)
|
||||
else:
|
||||
# AppSession handles all other BackMsg types.
|
||||
self._runtime.handle_backmsg(self._session_id, msg)
|
||||
@@ -0,0 +1,116 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
import tornado.web
|
||||
|
||||
import streamlit.web.server.routes
|
||||
from streamlit.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.components.types.base_component_registry import BaseComponentRegistry
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
class ComponentRequestHandler(tornado.web.RequestHandler):
|
||||
def initialize(self, registry: BaseComponentRegistry):
|
||||
self._registry = registry
|
||||
|
||||
def get(self, path: str) -> None:
|
||||
parts = path.split("/")
|
||||
component_name = parts[0]
|
||||
component_root = self._registry.get_component_path(component_name)
|
||||
if component_root is None:
|
||||
self.write("not found")
|
||||
self.set_status(404)
|
||||
return
|
||||
|
||||
# follow symlinks to get an accurate normalized path
|
||||
component_root = os.path.realpath(component_root)
|
||||
filename = "/".join(parts[1:])
|
||||
abspath = os.path.normpath(os.path.join(component_root, filename))
|
||||
|
||||
# Do NOT expose anything outside of the component root.
|
||||
if os.path.commonpath([component_root, abspath]) != component_root:
|
||||
self.write("forbidden")
|
||||
self.set_status(403)
|
||||
return
|
||||
try:
|
||||
with open(abspath, "rb") as file:
|
||||
contents = file.read()
|
||||
except OSError as e:
|
||||
_LOGGER.error(
|
||||
"ComponentRequestHandler: GET %s read error", abspath, exc_info=e
|
||||
)
|
||||
self.write("read error")
|
||||
self.set_status(404)
|
||||
return
|
||||
|
||||
self.write(contents)
|
||||
self.set_header("Content-Type", self.get_content_type(abspath))
|
||||
|
||||
self.set_extra_headers(path)
|
||||
|
||||
def set_extra_headers(self, path: str) -> None:
|
||||
"""Disable cache for HTML files.
|
||||
|
||||
Other assets like JS and CSS are suffixed with their hash, so they can
|
||||
be cached indefinitely.
|
||||
"""
|
||||
is_index_url = len(path) == 0
|
||||
|
||||
if is_index_url or path.endswith(".html"):
|
||||
self.set_header("Cache-Control", "no-cache")
|
||||
else:
|
||||
self.set_header("Cache-Control", "public")
|
||||
|
||||
def set_default_headers(self) -> None:
|
||||
if streamlit.web.server.routes.allow_cross_origin_requests():
|
||||
self.set_header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
def options(self) -> None:
|
||||
"""/OPTIONS handler for preflight CORS checks."""
|
||||
self.set_status(204)
|
||||
self.finish()
|
||||
|
||||
@staticmethod
|
||||
def get_content_type(abspath: str) -> str:
|
||||
"""Returns the ``Content-Type`` header to be used for this request.
|
||||
From tornado.web.StaticFileHandler.
|
||||
"""
|
||||
mime_type, encoding = mimetypes.guess_type(abspath)
|
||||
# per RFC 6713, use the appropriate type for a gzip compressed file
|
||||
if encoding == "gzip":
|
||||
return "application/gzip"
|
||||
# As of 2015-07-21 there is no bzip2 encoding defined at
|
||||
# http://www.iana.org/assignments/media-types/media-types.xhtml
|
||||
# So for that (and any other encoding), use octet-stream.
|
||||
elif encoding is not None:
|
||||
return "application/octet-stream"
|
||||
elif mime_type is not None:
|
||||
return mime_type
|
||||
# if mime_type not detected, use application/octet-stream
|
||||
else:
|
||||
return "application/octet-stream"
|
||||
|
||||
@staticmethod
|
||||
def get_url(file_id: str) -> str:
|
||||
"""Return the URL for a component file with the given ID."""
|
||||
return f"components/{file_id}"
|
||||
@@ -0,0 +1,141 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import quote
|
||||
|
||||
import tornado.web
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorageError
|
||||
from streamlit.runtime.memory_media_file_storage import (
|
||||
MemoryMediaFileStorage,
|
||||
get_extension_for_mimetype,
|
||||
)
|
||||
from streamlit.web.server import allow_cross_origin_requests
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
class MediaFileHandler(tornado.web.StaticFileHandler):
|
||||
_storage: MemoryMediaFileStorage
|
||||
|
||||
@classmethod
|
||||
def initialize_storage(cls, storage: MemoryMediaFileStorage) -> None:
|
||||
"""Set the MemoryMediaFileStorage object used by instances of this
|
||||
handler. Must be called on server startup.
|
||||
"""
|
||||
# This is a class method, rather than an instance method, because
|
||||
# `get_content()` is a class method and needs to access the storage
|
||||
# instance.
|
||||
cls._storage = storage
|
||||
|
||||
def set_default_headers(self) -> None:
|
||||
if allow_cross_origin_requests():
|
||||
self.set_header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
def set_extra_headers(self, path: str) -> None:
|
||||
"""Add Content-Disposition header for downloadable files.
|
||||
|
||||
Set header value to "attachment" indicating that file should be saved
|
||||
locally instead of displaying inline in browser.
|
||||
|
||||
We also set filename to specify the filename for downloaded files.
|
||||
Used for serving downloadable files, like files stored via the
|
||||
`st.download_button` widget.
|
||||
"""
|
||||
media_file = self._storage.get_file(path)
|
||||
|
||||
if media_file and media_file.kind == MediaFileKind.DOWNLOADABLE:
|
||||
filename = media_file.filename
|
||||
|
||||
if not filename:
|
||||
filename = f"streamlit_download{get_extension_for_mimetype(media_file.mimetype)}"
|
||||
|
||||
try:
|
||||
# Check that the value can be encoded in latin1. Latin1 is
|
||||
# the default encoding for headers.
|
||||
filename.encode("latin1")
|
||||
file_expr = f'filename="{filename}"'
|
||||
except UnicodeEncodeError:
|
||||
# RFC5987 syntax.
|
||||
# See: https://datatracker.ietf.org/doc/html/rfc5987
|
||||
file_expr = f"filename*=utf-8''{quote(filename)}"
|
||||
|
||||
self.set_header("Content-Disposition", f"attachment; {file_expr}")
|
||||
|
||||
# Overriding StaticFileHandler to use the MediaFileManager
|
||||
#
|
||||
# From the Tornado docs:
|
||||
# To replace all interaction with the filesystem (e.g. to serve
|
||||
# static content from a database), override `get_content`,
|
||||
# `get_content_size`, `get_modified_time`, `get_absolute_path`, and
|
||||
# `validate_absolute_path`.
|
||||
def validate_absolute_path(self, root: str, absolute_path: str) -> str:
|
||||
try:
|
||||
self._storage.get_file(absolute_path)
|
||||
except MediaFileStorageError:
|
||||
_LOGGER.error("MediaFileHandler: Missing file %s", absolute_path)
|
||||
raise tornado.web.HTTPError(404, "not found")
|
||||
|
||||
return absolute_path
|
||||
|
||||
def get_content_size(self) -> int:
|
||||
abspath = self.absolute_path
|
||||
if abspath is None:
|
||||
return 0
|
||||
|
||||
media_file = self._storage.get_file(abspath)
|
||||
return media_file.content_size
|
||||
|
||||
def get_modified_time(self) -> None:
|
||||
# We do not track last modified time, but this can be improved to
|
||||
# allow caching among files in the MediaFileManager
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_absolute_path(cls, root: str, path: str) -> str:
|
||||
# All files are stored in memory, so the absolute path is just the
|
||||
# path itself. In the MediaFileHandler, it's just the filename
|
||||
return path
|
||||
|
||||
@classmethod
|
||||
def get_content(
|
||||
cls, abspath: str, start: int | None = None, end: int | None = None
|
||||
):
|
||||
_LOGGER.debug("MediaFileHandler: GET %s", abspath)
|
||||
|
||||
try:
|
||||
# abspath is the hash as used `get_absolute_path`
|
||||
media_file = cls._storage.get_file(abspath)
|
||||
except Exception:
|
||||
_LOGGER.error("MediaFileHandler: Missing file %s", abspath)
|
||||
return None
|
||||
|
||||
_LOGGER.debug(
|
||||
"MediaFileHandler: Sending %s file %s", media_file.mimetype, abspath
|
||||
)
|
||||
|
||||
# If there is no start and end, just return the full content
|
||||
if start is None and end is None:
|
||||
return media_file.content
|
||||
|
||||
if start is None:
|
||||
start = 0
|
||||
if end is None:
|
||||
end = len(media_file.content)
|
||||
|
||||
# content is bytes that work just by slicing supplied by start and end
|
||||
return media_file.content[start:end]
|
||||
@@ -0,0 +1,176 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import tornado.web
|
||||
|
||||
from streamlit.auth_util import (
|
||||
AuthCache,
|
||||
decode_provider_token,
|
||||
generate_default_provider_section,
|
||||
get_secrets_auth_section,
|
||||
)
|
||||
from streamlit.errors import StreamlitAuthError
|
||||
from streamlit.url_util import make_url_path
|
||||
from streamlit.web.server.oidc_mixin import TornadoOAuth, TornadoOAuth2App
|
||||
from streamlit.web.server.server_util import AUTH_COOKIE_NAME
|
||||
|
||||
auth_cache = AuthCache()
|
||||
|
||||
|
||||
def create_oauth_client(provider: str) -> tuple[TornadoOAuth2App, str]:
|
||||
"""Create an OAuth client for the given provider based on secrets.toml configuration."""
|
||||
auth_section = get_secrets_auth_section()
|
||||
if auth_section:
|
||||
redirect_uri = auth_section.get("redirect_uri", None)
|
||||
config = auth_section.to_dict()
|
||||
else:
|
||||
config = {}
|
||||
redirect_uri = "/"
|
||||
|
||||
provider_section = config.setdefault(provider, {})
|
||||
|
||||
if not provider_section and provider == "default":
|
||||
provider_section = generate_default_provider_section(auth_section)
|
||||
config["default"] = provider_section
|
||||
|
||||
provider_client_kwargs = provider_section.setdefault("client_kwargs", {})
|
||||
if "scope" not in provider_client_kwargs:
|
||||
provider_client_kwargs["scope"] = "openid email profile"
|
||||
if "prompt" not in provider_client_kwargs:
|
||||
provider_client_kwargs["prompt"] = "select_account"
|
||||
|
||||
oauth = TornadoOAuth(config, cache=auth_cache)
|
||||
oauth.register(provider)
|
||||
return oauth.create_client(provider), redirect_uri
|
||||
|
||||
|
||||
class AuthHandlerMixin(tornado.web.RequestHandler):
|
||||
"""Mixin for handling auth cookies. Added for compatibility with Tornado < 6.3.0."""
|
||||
|
||||
def initialize(self, base_url: str) -> None:
|
||||
self.base_url = base_url
|
||||
|
||||
def redirect_to_base(self) -> None:
|
||||
self.redirect(make_url_path(self.base_url, "/"))
|
||||
|
||||
def set_auth_cookie(self, user_info: dict[str, Any]) -> None:
|
||||
serialized_cookie_value = json.dumps(user_info)
|
||||
try:
|
||||
# We don't specify Tornado secure flag here because it leads to missing cookie on Safari.
|
||||
# The OIDC flow should work only on secure context anyway (localhost or HTTPS),
|
||||
# so specifying the secure flag here will not add anything in terms of security.
|
||||
self.set_signed_cookie(
|
||||
AUTH_COOKIE_NAME,
|
||||
serialized_cookie_value,
|
||||
httpOnly=True,
|
||||
)
|
||||
except AttributeError:
|
||||
self.set_secure_cookie(
|
||||
AUTH_COOKIE_NAME,
|
||||
serialized_cookie_value,
|
||||
httponly=True,
|
||||
)
|
||||
|
||||
def clear_auth_cookie(self) -> None:
|
||||
self.clear_cookie(AUTH_COOKIE_NAME)
|
||||
|
||||
|
||||
class AuthLoginHandler(AuthHandlerMixin, tornado.web.RequestHandler):
|
||||
async def get(self):
|
||||
"""Redirect to the OAuth provider login page."""
|
||||
provider = self._parse_provider_token()
|
||||
if provider is None:
|
||||
self.redirect_to_base()
|
||||
return
|
||||
|
||||
client, redirect_uri = create_oauth_client(provider)
|
||||
try:
|
||||
client.authorize_redirect(self, redirect_uri)
|
||||
except Exception as e:
|
||||
self.send_error(400, reason=str(e))
|
||||
|
||||
def _parse_provider_token(self) -> str | None:
|
||||
provider_token = self.get_argument("provider", None)
|
||||
try:
|
||||
if provider_token is None:
|
||||
raise StreamlitAuthError("Missing provider token")
|
||||
payload = decode_provider_token(provider_token)
|
||||
except StreamlitAuthError:
|
||||
return None
|
||||
|
||||
return payload["provider"]
|
||||
|
||||
|
||||
class AuthLogoutHandler(AuthHandlerMixin, tornado.web.RequestHandler):
|
||||
def get(self):
|
||||
self.clear_auth_cookie()
|
||||
self.redirect_to_base()
|
||||
|
||||
|
||||
class AuthCallbackHandler(AuthHandlerMixin, tornado.web.RequestHandler):
|
||||
async def get(self):
|
||||
provider = self._get_provider_by_state()
|
||||
origin = self._get_origin_from_secrets()
|
||||
if origin is None:
|
||||
self.redirect_to_base()
|
||||
return
|
||||
|
||||
error = self.get_argument("error", None)
|
||||
if error:
|
||||
self.redirect_to_base()
|
||||
return
|
||||
|
||||
if provider is None:
|
||||
self.redirect_to_base()
|
||||
return
|
||||
|
||||
client, _ = create_oauth_client(provider)
|
||||
token = client.authorize_access_token(self)
|
||||
user = token.get("userinfo")
|
||||
|
||||
cookie_value = dict(user, origin=origin, is_logged_in=True)
|
||||
if user:
|
||||
self.set_auth_cookie(cookie_value)
|
||||
self.redirect_to_base()
|
||||
|
||||
def _get_provider_by_state(self) -> str | None:
|
||||
state_code_from_url = self.get_argument("state")
|
||||
current_cache_keys = list(auth_cache.get_dict().keys())
|
||||
state_provider_mapping = {}
|
||||
for key in current_cache_keys:
|
||||
_, _, recorded_provider, code = key.split("_")
|
||||
state_provider_mapping[code] = recorded_provider
|
||||
|
||||
provider: str | None = state_provider_mapping.get(state_code_from_url, None)
|
||||
return provider
|
||||
|
||||
def _get_origin_from_secrets(self) -> str | None:
|
||||
redirect_uri = None
|
||||
auth_section = get_secrets_auth_section()
|
||||
if auth_section:
|
||||
redirect_uri = auth_section.get("redirect_uri", None)
|
||||
|
||||
if not redirect_uri:
|
||||
return None
|
||||
|
||||
redirect_uri_parsed = urlparse(redirect_uri)
|
||||
origin_from_redirect_uri: str = (
|
||||
redirect_uri_parsed.scheme + "://" + redirect_uri_parsed.netloc
|
||||
)
|
||||
return origin_from_redirect_uri
|
||||
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tornado.web
|
||||
from authlib.integrations.base_client import ( # type: ignore[import-untyped]
|
||||
BaseApp,
|
||||
BaseOAuth,
|
||||
OAuth2Mixin,
|
||||
OAuthError,
|
||||
OpenIDMixin,
|
||||
)
|
||||
from authlib.integrations.requests_client import ( # type: ignore[import-untyped]
|
||||
OAuth2Session,
|
||||
)
|
||||
|
||||
from streamlit.web.server.authlib_tornado_integration import TornadoIntegration
|
||||
|
||||
|
||||
class TornadoOAuth2App(OAuth2Mixin, OpenIDMixin, BaseApp): # type: ignore[misc]
|
||||
client_cls = OAuth2Session
|
||||
|
||||
def load_server_metadata(self):
|
||||
"""We enforce S256 code challenge method if it is supported by the server."""
|
||||
result = super().load_server_metadata()
|
||||
if "S256" in result.get("code_challenge_methods_supported", []):
|
||||
self.client_kwargs["code_challenge_method"] = "S256"
|
||||
return result
|
||||
|
||||
def authorize_redirect(
|
||||
self, request_handler: tornado.web.RequestHandler, redirect_uri=None, **kwargs
|
||||
):
|
||||
"""Create a HTTP Redirect for Authorization Endpoint.
|
||||
|
||||
:param request_handler: HTTP request instance from Tornado.
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: A HTTP redirect response.
|
||||
"""
|
||||
auth_context = self.create_authorization_url(redirect_uri, **kwargs)
|
||||
self._save_authorize_data(redirect_uri=redirect_uri, **auth_context)
|
||||
request_handler.redirect(auth_context["url"], status=302)
|
||||
|
||||
def authorize_access_token(
|
||||
self, request_handler: tornado.web.RequestHandler, **kwargs
|
||||
):
|
||||
"""
|
||||
:param request_handler: HTTP request instance from Tornado.
|
||||
:return: A token dict.
|
||||
"""
|
||||
error = request_handler.get_argument("error", None)
|
||||
if error:
|
||||
description = request_handler.get_argument("error_description", None)
|
||||
raise OAuthError(error=error, description=description)
|
||||
|
||||
params = {
|
||||
"code": request_handler.get_argument("code"),
|
||||
"state": request_handler.get_argument("state"),
|
||||
}
|
||||
|
||||
assert self.framework.cache is not None
|
||||
session = None
|
||||
|
||||
claims_options = kwargs.pop("claims_options", None)
|
||||
state_data = self.framework.get_state_data(session, params.get("state"))
|
||||
self.framework.clear_state_data(session, params.get("state"))
|
||||
params = self._format_state_params(state_data, params)
|
||||
token = self.fetch_access_token(**params, **kwargs)
|
||||
|
||||
if "id_token" in token and "nonce" in state_data:
|
||||
userinfo = self.parse_id_token(
|
||||
token, nonce=state_data["nonce"], claims_options=claims_options
|
||||
)
|
||||
token = {**token, "userinfo": userinfo}
|
||||
return token
|
||||
|
||||
def _save_authorize_data(self, **kwargs):
|
||||
"""Authlib underlying uses the concept of "session" to store state data.
|
||||
In Tornado, we don't have a session, so we use the framework's cache option.
|
||||
"""
|
||||
state = kwargs.pop("state", None)
|
||||
if state:
|
||||
assert self.framework.cache is not None
|
||||
session = None
|
||||
self.framework.set_state_data(session, state, kwargs)
|
||||
else:
|
||||
raise RuntimeError("Missing state value")
|
||||
|
||||
|
||||
class TornadoOAuth(BaseOAuth): # type: ignore[misc]
|
||||
oauth2_client_cls = TornadoOAuth2App
|
||||
framework_integration_cls = TornadoIntegration
|
||||
|
||||
def __init__(self, config=None, cache=None, fetch_token=None, update_token=None):
|
||||
super().__init__(
|
||||
cache=cache, fetch_token=fetch_token, update_token=update_token
|
||||
)
|
||||
self.config = config
|
||||
@@ -0,0 +1,296 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Final
|
||||
|
||||
import tornado.web
|
||||
|
||||
from streamlit import config, file_util
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.runtime_util import serialize_forward_msg
|
||||
from streamlit.web.server.server_util import (
|
||||
emit_endpoint_deprecation_notice,
|
||||
is_xsrf_enabled,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
def allow_cross_origin_requests() -> bool:
|
||||
"""True if cross-origin requests are allowed.
|
||||
|
||||
We only allow cross-origin requests when CORS protection has been disabled
|
||||
with server.enableCORS=False or if using the Node server. When using the
|
||||
Node server, we have a dev and prod port, which count as two origins.
|
||||
|
||||
"""
|
||||
return not config.get_option("server.enableCORS") or config.get_option(
|
||||
"global.developmentMode"
|
||||
)
|
||||
|
||||
|
||||
class StaticFileHandler(tornado.web.StaticFileHandler):
|
||||
def initialize(
|
||||
self,
|
||||
path: str,
|
||||
default_filename: str | None = None,
|
||||
reserved_paths: Sequence[str] = (),
|
||||
):
|
||||
self._reserved_paths = reserved_paths
|
||||
|
||||
super().initialize(path, default_filename)
|
||||
|
||||
def set_extra_headers(self, path: str) -> None:
|
||||
"""Disable cache for HTML files.
|
||||
|
||||
Other assets like JS and CSS are suffixed with their hash, so they can
|
||||
be cached indefinitely.
|
||||
"""
|
||||
is_index_url = len(path) == 0
|
||||
if is_index_url or path.endswith(".html"):
|
||||
self.set_header("Cache-Control", "no-cache")
|
||||
else:
|
||||
self.set_header("Cache-Control", "public")
|
||||
|
||||
def validate_absolute_path(self, root: str, absolute_path: str) -> str | None:
|
||||
try:
|
||||
return super().validate_absolute_path(root, absolute_path)
|
||||
except tornado.web.HTTPError as e:
|
||||
# If the file is not found, and there are no reserved paths,
|
||||
# we try to serve the default file and allow the frontend to handle the issue.
|
||||
if e.status_code == 404:
|
||||
url_path = self.path
|
||||
# self.path is OS specific file path, we convert it to a URL path
|
||||
# for checking it against reserved paths.
|
||||
if os.path.sep != "/":
|
||||
url_path = url_path.replace(os.path.sep, "/")
|
||||
if any(url_path.endswith(x) for x in self._reserved_paths):
|
||||
raise e
|
||||
|
||||
self.path = self.parse_url_path(self.default_filename or "index.html")
|
||||
absolute_path = self.get_absolute_path(self.root, self.path)
|
||||
return super().validate_absolute_path(root, absolute_path)
|
||||
|
||||
raise e
|
||||
|
||||
def write_error(self, status_code: int, **kwargs) -> None:
|
||||
if status_code == 404:
|
||||
index_file = os.path.join(file_util.get_static_dir(), "index.html")
|
||||
self.render(index_file)
|
||||
else:
|
||||
super().write_error(status_code, **kwargs)
|
||||
|
||||
|
||||
class AddSlashHandler(tornado.web.RequestHandler):
|
||||
@tornado.web.addslash
|
||||
def get(self):
|
||||
pass
|
||||
|
||||
|
||||
class RemoveSlashHandler(tornado.web.RequestHandler):
|
||||
@tornado.web.removeslash
|
||||
def get(self):
|
||||
pass
|
||||
|
||||
|
||||
class _SpecialRequestHandler(tornado.web.RequestHandler):
|
||||
"""Superclass for "special" endpoints, like /healthz."""
|
||||
|
||||
def set_default_headers(self):
|
||||
self.set_header("Cache-Control", "no-cache")
|
||||
if allow_cross_origin_requests():
|
||||
self.set_header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
def options(self):
|
||||
"""/OPTIONS handler for preflight CORS checks.
|
||||
|
||||
When a browser is making a CORS request, it may sometimes first
|
||||
send an OPTIONS request, to check whether the server understands the
|
||||
CORS protocol. This is optional, and doesn't happen for every request
|
||||
or in every browser. If an OPTIONS request does get sent, and is not
|
||||
then handled by the server, the browser will fail the underlying
|
||||
request.
|
||||
|
||||
The proper way to handle this is to send a 204 response ("no content")
|
||||
with the CORS headers attached. (These headers are automatically added
|
||||
to every outgoing response, including OPTIONS responses,
|
||||
via set_default_headers().)
|
||||
|
||||
See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
|
||||
"""
|
||||
self.set_status(204)
|
||||
self.finish()
|
||||
|
||||
|
||||
class HealthHandler(_SpecialRequestHandler):
|
||||
def initialize(self, callback):
|
||||
"""Initialize the handler.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
callback : callable
|
||||
A function that returns True if the server is healthy
|
||||
|
||||
"""
|
||||
self._callback = callback
|
||||
|
||||
async def get(self):
|
||||
await self.handle_request()
|
||||
|
||||
# Some monitoring services only support the HTTP HEAD method for requests to
|
||||
# healthcheck endpoints, so we support HEAD as well to play nicely with them.
|
||||
async def head(self):
|
||||
await self.handle_request()
|
||||
|
||||
async def handle_request(self):
|
||||
if self.request.uri and "_stcore/" not in self.request.uri:
|
||||
new_path = (
|
||||
"/_stcore/script-health-check"
|
||||
if "script-health-check" in self.request.uri
|
||||
else "/_stcore/health"
|
||||
)
|
||||
emit_endpoint_deprecation_notice(self, new_path=new_path)
|
||||
|
||||
ok, msg = await self._callback()
|
||||
if ok:
|
||||
self.write(msg)
|
||||
self.set_status(200)
|
||||
|
||||
# Tornado will set the _streamlit_xsrf cookie automatically for the page on
|
||||
# request for the document. However, if the server is reset and
|
||||
# server.enableXsrfProtection is updated, the browser does not reload the document.
|
||||
# Manually setting the cookie on /healthz since it is pinged when the
|
||||
# browser is disconnected from the server.
|
||||
if is_xsrf_enabled():
|
||||
cookie_kwargs = self.settings.get("xsrf_cookie_kwargs", {})
|
||||
self.set_cookie(
|
||||
self.settings.get("xsrf_cookie_name", "_streamlit_xsrf"),
|
||||
self.xsrf_token,
|
||||
**cookie_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
# 503 = SERVICE_UNAVAILABLE
|
||||
self.set_status(503)
|
||||
self.write(msg)
|
||||
|
||||
|
||||
_DEFAULT_ALLOWED_MESSAGE_ORIGINS = [
|
||||
# Community-cloud related domains.
|
||||
# We can remove these in the future if community cloud
|
||||
# provides those domains via the host-config endpoint.
|
||||
"https://devel.streamlit.test",
|
||||
"https://*.streamlit.apptest",
|
||||
"https://*.streamlitapp.test",
|
||||
"https://*.streamlitapp.com",
|
||||
"https://share.streamlit.io",
|
||||
"https://share-demo.streamlit.io",
|
||||
"https://share-head.streamlit.io",
|
||||
"https://share-staging.streamlit.io",
|
||||
"https://*.demo.streamlit.run",
|
||||
"https://*.head.streamlit.run",
|
||||
"https://*.staging.streamlit.run",
|
||||
"https://*.streamlit.run",
|
||||
"https://*.demo.streamlit.app",
|
||||
"https://*.head.streamlit.app",
|
||||
"https://*.staging.streamlit.app",
|
||||
"https://*.streamlit.app",
|
||||
]
|
||||
|
||||
|
||||
class HostConfigHandler(_SpecialRequestHandler):
|
||||
def initialize(self):
|
||||
# Make a copy of the allowedOrigins list, since we might modify it later:
|
||||
self._allowed_origins = _DEFAULT_ALLOWED_MESSAGE_ORIGINS.copy()
|
||||
|
||||
if (
|
||||
config.get_option("global.developmentMode")
|
||||
and "http://localhost" not in self._allowed_origins
|
||||
):
|
||||
# Allow messages from localhost in dev mode for testing of host <-> guest communication
|
||||
self._allowed_origins.append("http://localhost")
|
||||
|
||||
async def get(self) -> None:
|
||||
self.write(
|
||||
{
|
||||
"allowedOrigins": self._allowed_origins,
|
||||
"useExternalAuthToken": False,
|
||||
# Default host configuration settings.
|
||||
"enableCustomParentMessages": False,
|
||||
"enforceDownloadInNewTab": False,
|
||||
"metricsUrl": "",
|
||||
"blockErrorDialogs": False,
|
||||
}
|
||||
)
|
||||
self.set_status(200)
|
||||
|
||||
|
||||
class MessageCacheHandler(tornado.web.RequestHandler):
|
||||
"""Returns ForwardMsgs from our MessageCache."""
|
||||
|
||||
def initialize(self, cache):
|
||||
"""Initializes the handler.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cache : MessageCache
|
||||
|
||||
"""
|
||||
self._cache = cache
|
||||
|
||||
def set_default_headers(self):
|
||||
if allow_cross_origin_requests():
|
||||
self.set_header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
def get(self):
|
||||
msg_hash = self.get_argument("hash", None)
|
||||
if not config.get_option("global.storeCachedForwardMessagesInMemory"):
|
||||
# We use rare status code here, to distinguish between normal 404s.
|
||||
self.set_status(418)
|
||||
self.finish()
|
||||
return
|
||||
if msg_hash is None:
|
||||
# Hash is missing! This is a malformed request.
|
||||
_LOGGER.error(
|
||||
"HTTP request for cached message is missing the hash attribute."
|
||||
)
|
||||
self.set_status(404)
|
||||
raise tornado.web.Finish()
|
||||
|
||||
message = self._cache.get_message(msg_hash)
|
||||
if message is None:
|
||||
# Message not in our cache.
|
||||
_LOGGER.error(
|
||||
"HTTP request for cached message could not be fulfilled. "
|
||||
"No such message"
|
||||
)
|
||||
self.set_status(404)
|
||||
raise tornado.web.Finish()
|
||||
|
||||
_LOGGER.debug("MessageCache HIT")
|
||||
msg_str = serialize_forward_msg(message)
|
||||
self.set_header("Content-Type", "application/octet-stream")
|
||||
self.write(msg_str)
|
||||
self.set_status(200)
|
||||
|
||||
def options(self):
|
||||
"""/OPTIONS handler for preflight CORS checks."""
|
||||
self.set_status(204)
|
||||
self.finish()
|
||||
@@ -0,0 +1,479 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
|
||||
import tornado.concurrent
|
||||
import tornado.locks
|
||||
import tornado.netutil
|
||||
import tornado.web
|
||||
import tornado.websocket
|
||||
from tornado.httpserver import HTTPServer
|
||||
|
||||
from streamlit import cli_util, config, file_util, util
|
||||
from streamlit.auth_util import is_authlib_installed
|
||||
from streamlit.config_option import ConfigOption
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime import Runtime, RuntimeConfig, RuntimeState
|
||||
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
|
||||
from streamlit.runtime.memory_session_storage import MemorySessionStorage
|
||||
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
|
||||
from streamlit.runtime.runtime_util import get_max_message_size_bytes
|
||||
from streamlit.web.cache_storage_manager_config import (
|
||||
create_default_cache_storage_manager,
|
||||
)
|
||||
from streamlit.web.server.app_static_file_handler import AppStaticFileHandler
|
||||
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler
|
||||
from streamlit.web.server.component_request_handler import ComponentRequestHandler
|
||||
from streamlit.web.server.media_file_handler import MediaFileHandler
|
||||
from streamlit.web.server.routes import (
|
||||
AddSlashHandler,
|
||||
HealthHandler,
|
||||
HostConfigHandler,
|
||||
MessageCacheHandler,
|
||||
RemoveSlashHandler,
|
||||
StaticFileHandler,
|
||||
)
|
||||
from streamlit.web.server.server_util import (
|
||||
DEVELOPMENT_PORT,
|
||||
get_cookie_secret,
|
||||
is_xsrf_enabled,
|
||||
make_url_path_regex,
|
||||
)
|
||||
from streamlit.web.server.stats_request_handler import StatsRequestHandler
|
||||
from streamlit.web.server.upload_file_request_handler import UploadFileRequestHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
from ssl import SSLContext
|
||||
|
||||
_LOGGER: Final = get_logger(__name__)
|
||||
|
||||
TORNADO_SETTINGS = {
|
||||
# Gzip HTTP responses.
|
||||
"compress_response": True,
|
||||
# Ping every 1s to keep WS alive.
|
||||
# 2021.06.22: this value was previously 20s, and was causing
|
||||
# connection instability for a small number of users. This smaller
|
||||
# ping_interval fixes that instability.
|
||||
# https://github.com/streamlit/streamlit/issues/3196
|
||||
"websocket_ping_interval": 1,
|
||||
# If we don't get a ping response within 30s, the connection
|
||||
# is timed out.
|
||||
"websocket_ping_timeout": 30,
|
||||
"xsrf_cookie_name": "_streamlit_xsrf",
|
||||
}
|
||||
|
||||
# When server.port is not available it will look for the next available port
|
||||
# up to MAX_PORT_SEARCH_RETRIES.
|
||||
MAX_PORT_SEARCH_RETRIES: Final = 100
|
||||
|
||||
# When server.address starts with this prefix, the server will bind
|
||||
# to an unix socket.
|
||||
UNIX_SOCKET_PREFIX: Final = "unix://"
|
||||
|
||||
MEDIA_ENDPOINT: Final = "/media"
|
||||
UPLOAD_FILE_ENDPOINT: Final = "/_stcore/upload_file"
|
||||
STREAM_ENDPOINT: Final = r"_stcore/stream"
|
||||
METRIC_ENDPOINT: Final = r"(?:st-metrics|_stcore/metrics)"
|
||||
MESSAGE_ENDPOINT: Final = r"_stcore/message"
|
||||
NEW_HEALTH_ENDPOINT: Final = "_stcore/health"
|
||||
HEALTH_ENDPOINT: Final = rf"(?:healthz|{NEW_HEALTH_ENDPOINT})"
|
||||
HOST_CONFIG_ENDPOINT: Final = r"_stcore/host-config"
|
||||
SCRIPT_HEALTH_CHECK_ENDPOINT: Final = (
|
||||
r"(?:script-health-check|_stcore/script-health-check)"
|
||||
)
|
||||
|
||||
OAUTH2_CALLBACK_ENDPOINT: Final = "/oauth2callback"
|
||||
AUTH_LOGIN_ENDPOINT: Final = "/auth/login"
|
||||
AUTH_LOGOUT_ENDPOINT: Final = "/auth/logout"
|
||||
|
||||
|
||||
class RetriesExceeded(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def server_port_is_manually_set() -> bool:
|
||||
return config.is_manually_set("server.port")
|
||||
|
||||
|
||||
def server_address_is_unix_socket() -> bool:
|
||||
address = config.get_option("server.address")
|
||||
return address is not None and address.startswith(UNIX_SOCKET_PREFIX)
|
||||
|
||||
|
||||
def start_listening(app: tornado.web.Application) -> None:
|
||||
"""Makes the server start listening at the configured port.
|
||||
|
||||
In case the port is already taken it tries listening to the next available
|
||||
port. It will error after MAX_PORT_SEARCH_RETRIES attempts.
|
||||
|
||||
"""
|
||||
cert_file = config.get_option("server.sslCertFile")
|
||||
key_file = config.get_option("server.sslKeyFile")
|
||||
ssl_options = _get_ssl_options(cert_file, key_file)
|
||||
|
||||
http_server = HTTPServer(
|
||||
app,
|
||||
max_buffer_size=config.get_option("server.maxUploadSize") * 1024 * 1024,
|
||||
ssl_options=ssl_options,
|
||||
)
|
||||
|
||||
if server_address_is_unix_socket():
|
||||
start_listening_unix_socket(http_server)
|
||||
else:
|
||||
start_listening_tcp_socket(http_server)
|
||||
|
||||
|
||||
def _get_ssl_options(cert_file: str | None, key_file: str | None) -> SSLContext | None:
|
||||
if bool(cert_file) != bool(key_file):
|
||||
_LOGGER.error(
|
||||
"Options 'server.sslCertFile' and 'server.sslKeyFile' must "
|
||||
"be set together. Set missing options or delete existing options."
|
||||
)
|
||||
sys.exit(1)
|
||||
if cert_file and key_file:
|
||||
# ssl_ctx.load_cert_chain raise exception as below, but it is not
|
||||
# sufficiently user-friendly
|
||||
# FileNotFoundError: [Errno 2] No such file or directory
|
||||
if not Path(cert_file).exists():
|
||||
_LOGGER.error("Cert file '%s' does not exist.", cert_file)
|
||||
sys.exit(1)
|
||||
if not Path(key_file).exists():
|
||||
_LOGGER.error("Key file '%s' does not exist.", key_file)
|
||||
sys.exit(1)
|
||||
|
||||
import ssl
|
||||
|
||||
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
# When the SSL certificate fails to load, an exception is raised as below,
|
||||
# but it is not sufficiently user-friendly.
|
||||
# ssl.SSLError: [SSL] PEM lib (_ssl.c:4067)
|
||||
try:
|
||||
ssl_ctx.load_cert_chain(cert_file, key_file)
|
||||
except ssl.SSLError:
|
||||
_LOGGER.error(
|
||||
"Failed to load SSL certificate. Make sure "
|
||||
"cert file '%s' and key file '%s' are correct.",
|
||||
cert_file,
|
||||
key_file,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
return ssl_ctx
|
||||
return None
|
||||
|
||||
|
||||
def start_listening_unix_socket(http_server: HTTPServer) -> None:
|
||||
address = config.get_option("server.address")
|
||||
file_name = os.path.expanduser(address[len(UNIX_SOCKET_PREFIX) :])
|
||||
|
||||
unix_socket = tornado.netutil.bind_unix_socket(file_name)
|
||||
http_server.add_socket(unix_socket)
|
||||
|
||||
|
||||
def start_listening_tcp_socket(http_server: HTTPServer) -> None:
|
||||
call_count = 0
|
||||
|
||||
port = None
|
||||
while call_count < MAX_PORT_SEARCH_RETRIES:
|
||||
address = config.get_option("server.address")
|
||||
port = config.get_option("server.port")
|
||||
|
||||
if int(port) == DEVELOPMENT_PORT:
|
||||
_LOGGER.warning(
|
||||
"Port %s is reserved for internal development. "
|
||||
"It is strongly recommended to select an alternative port "
|
||||
"for `server.port`.",
|
||||
DEVELOPMENT_PORT,
|
||||
)
|
||||
|
||||
try:
|
||||
http_server.listen(port, address)
|
||||
break # It worked! So let's break out of the loop.
|
||||
|
||||
except OSError as e:
|
||||
if e.errno == errno.EADDRINUSE:
|
||||
if server_port_is_manually_set():
|
||||
_LOGGER.error("Port %s is already in use", port)
|
||||
sys.exit(1)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Port %s already in use, trying to use the next one.", port
|
||||
)
|
||||
port += 1
|
||||
# Don't use the development port here:
|
||||
if port == DEVELOPMENT_PORT:
|
||||
port += 1
|
||||
|
||||
config.set_option(
|
||||
"server.port", port, ConfigOption.STREAMLIT_DEFINITION
|
||||
)
|
||||
call_count += 1
|
||||
else:
|
||||
raise
|
||||
|
||||
if call_count >= MAX_PORT_SEARCH_RETRIES:
|
||||
raise RetriesExceeded(
|
||||
f"Cannot start Streamlit server. Port {port} is already in use, and "
|
||||
f"Streamlit was unable to find a free port after {MAX_PORT_SEARCH_RETRIES} attempts.",
|
||||
)
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, main_script_path: str, is_hello: bool):
|
||||
"""Create the server. It won't be started yet."""
|
||||
_set_tornado_log_levels()
|
||||
self.initialize_mimetypes()
|
||||
|
||||
self._main_script_path = main_script_path
|
||||
|
||||
# Initialize MediaFileStorage and its associated endpoint
|
||||
media_file_storage = MemoryMediaFileStorage(MEDIA_ENDPOINT)
|
||||
MediaFileHandler.initialize_storage(media_file_storage)
|
||||
|
||||
uploaded_file_mgr = MemoryUploadedFileManager(UPLOAD_FILE_ENDPOINT)
|
||||
|
||||
self._runtime = Runtime(
|
||||
RuntimeConfig(
|
||||
script_path=main_script_path,
|
||||
command_line=None,
|
||||
media_file_storage=media_file_storage,
|
||||
uploaded_file_manager=uploaded_file_mgr,
|
||||
cache_storage_manager=create_default_cache_storage_manager(),
|
||||
is_hello=is_hello,
|
||||
session_storage=MemorySessionStorage(
|
||||
ttl_seconds=config.get_option("server.disconnectedSessionTTL")
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
self._runtime.stats_mgr.register_provider(media_file_storage)
|
||||
|
||||
@classmethod
|
||||
def initialize_mimetypes(cls) -> None:
|
||||
"""Ensures that common mime-types are robust against system misconfiguration."""
|
||||
mimetypes.add_type("text/html", ".html")
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
mimetypes.add_type("image/webp", ".webp")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
@property
|
||||
def main_script_path(self) -> str:
|
||||
return self._main_script_path
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the server.
|
||||
|
||||
When this returns, Streamlit is ready to accept new sessions.
|
||||
"""
|
||||
|
||||
_LOGGER.debug("Starting server...")
|
||||
|
||||
app = self._create_app()
|
||||
start_listening(app)
|
||||
|
||||
port = config.get_option("server.port")
|
||||
_LOGGER.debug("Server started on port %s", port)
|
||||
|
||||
await self._runtime.start()
|
||||
|
||||
@property
|
||||
def stopped(self) -> Awaitable[None]:
|
||||
"""A Future that completes when the Server's run loop has exited."""
|
||||
return self._runtime.stopped
|
||||
|
||||
def _create_app(self) -> tornado.web.Application:
|
||||
"""Create our tornado web app."""
|
||||
base = config.get_option("server.baseUrlPath")
|
||||
|
||||
routes: list[Any] = [
|
||||
(
|
||||
make_url_path_regex(base, STREAM_ENDPOINT),
|
||||
BrowserWebSocketHandler,
|
||||
{"runtime": self._runtime},
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, HEALTH_ENDPOINT),
|
||||
HealthHandler,
|
||||
{"callback": lambda: self._runtime.is_ready_for_browser_connection},
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, MESSAGE_ENDPOINT),
|
||||
MessageCacheHandler,
|
||||
{"cache": self._runtime.message_cache},
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, METRIC_ENDPOINT),
|
||||
StatsRequestHandler,
|
||||
{"stats_manager": self._runtime.stats_mgr},
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, HOST_CONFIG_ENDPOINT),
|
||||
HostConfigHandler,
|
||||
),
|
||||
(
|
||||
make_url_path_regex(
|
||||
base,
|
||||
rf"{UPLOAD_FILE_ENDPOINT}/(?P<session_id>[^/]+)/(?P<file_id>[^/]+)",
|
||||
),
|
||||
UploadFileRequestHandler,
|
||||
{
|
||||
"file_mgr": self._runtime.uploaded_file_mgr,
|
||||
"is_active_session": self._runtime.is_active_session,
|
||||
},
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, f"{MEDIA_ENDPOINT}/(.*)"),
|
||||
MediaFileHandler,
|
||||
{"path": ""},
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, "component/(.*)"),
|
||||
ComponentRequestHandler,
|
||||
{"registry": self._runtime.component_registry},
|
||||
),
|
||||
]
|
||||
|
||||
if config.get_option("server.scriptHealthCheckEnabled"):
|
||||
routes.extend(
|
||||
[
|
||||
(
|
||||
make_url_path_regex(base, SCRIPT_HEALTH_CHECK_ENDPOINT),
|
||||
HealthHandler,
|
||||
{
|
||||
"callback": lambda: self._runtime.does_script_run_without_error()
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if config.get_option("server.enableStaticServing"):
|
||||
routes.extend(
|
||||
[
|
||||
(
|
||||
make_url_path_regex(base, "app/static/(.*)"),
|
||||
AppStaticFileHandler,
|
||||
{"path": file_util.get_app_static_dir(self.main_script_path)},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if is_authlib_installed():
|
||||
from streamlit.web.server.oauth_authlib_routes import (
|
||||
AuthCallbackHandler,
|
||||
AuthLoginHandler,
|
||||
AuthLogoutHandler,
|
||||
)
|
||||
|
||||
routes.extend(
|
||||
[
|
||||
(
|
||||
make_url_path_regex(base, OAUTH2_CALLBACK_ENDPOINT),
|
||||
AuthCallbackHandler,
|
||||
{"base_url": base},
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, AUTH_LOGIN_ENDPOINT),
|
||||
AuthLoginHandler,
|
||||
{"base_url": base},
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, AUTH_LOGOUT_ENDPOINT),
|
||||
AuthLogoutHandler,
|
||||
{"base_url": base},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if config.get_option("global.developmentMode"):
|
||||
_LOGGER.debug("Serving static content from the Node dev server")
|
||||
else:
|
||||
static_path = file_util.get_static_dir()
|
||||
_LOGGER.debug("Serving static content from %s", static_path)
|
||||
|
||||
routes.extend(
|
||||
[
|
||||
(
|
||||
# We want to remove paths with a trailing slash, but if the path
|
||||
# starts with a double slash //, the redirect will point
|
||||
# the browser to the wrong host.
|
||||
make_url_path_regex(
|
||||
base, "(?!/)(.*)", trailing_slash="required"
|
||||
),
|
||||
RemoveSlashHandler,
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, "(.*)"),
|
||||
StaticFileHandler,
|
||||
{
|
||||
"path": "%s/" % static_path,
|
||||
"default_filename": "index.html",
|
||||
"reserved_paths": [
|
||||
# These paths are required for identifying
|
||||
# the base url path.
|
||||
NEW_HEALTH_ENDPOINT,
|
||||
HOST_CONFIG_ENDPOINT,
|
||||
],
|
||||
},
|
||||
),
|
||||
(
|
||||
make_url_path_regex(base, trailing_slash="prohibited"),
|
||||
AddSlashHandler,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
return tornado.web.Application(
|
||||
routes,
|
||||
cookie_secret=get_cookie_secret(),
|
||||
xsrf_cookies=is_xsrf_enabled(),
|
||||
# Set the websocket message size. The default value is too low.
|
||||
websocket_max_message_size=get_max_message_size_bytes(),
|
||||
**TORNADO_SETTINGS, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
@property
|
||||
def browser_is_connected(self) -> bool:
|
||||
return self._runtime.state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED
|
||||
|
||||
@property
|
||||
def is_running_hello(self) -> bool:
|
||||
from streamlit.hello import streamlit_app
|
||||
|
||||
return self._main_script_path == streamlit_app.__file__
|
||||
|
||||
def stop(self) -> None:
|
||||
cli_util.print_to_cli(" Stopping...", fg="blue")
|
||||
self._runtime.stop()
|
||||
|
||||
|
||||
def _set_tornado_log_levels() -> None:
|
||||
if not config.get_option("global.developmentMode"):
|
||||
# Hide logs unless they're super important.
|
||||
# Example of stuff we don't care about: 404 about .js.map files.
|
||||
logging.getLogger("tornado.access").setLevel(logging.ERROR)
|
||||
logging.getLogger("tornado.application").setLevel(logging.ERROR)
|
||||
logging.getLogger("tornado.general").setLevel(logging.ERROR)
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Server related utility functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final, Literal
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from streamlit import config, net_util, url_util
|
||||
from streamlit.runtime.secrets import secrets_singleton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
# The port reserved for internal development.
|
||||
DEVELOPMENT_PORT: Final = 3000
|
||||
|
||||
AUTH_COOKIE_NAME: Final = "_streamlit_user"
|
||||
|
||||
|
||||
def is_url_from_allowed_origins(url: str) -> bool:
|
||||
"""Return True if URL is from allowed origins (for CORS purpose).
|
||||
|
||||
Allowed origins:
|
||||
1. localhost
|
||||
2. The internal and external IP addresses of the machine where this
|
||||
function was called from.
|
||||
|
||||
If `server.enableCORS` is False, this allows all origins.
|
||||
"""
|
||||
if not config.get_option("server.enableCORS"):
|
||||
# Allow everything when CORS is disabled.
|
||||
return True
|
||||
|
||||
hostname = url_util.get_hostname(url)
|
||||
|
||||
allowed_domains = [ # List[Union[str, Callable[[], Optional[str]]]]
|
||||
# Check localhost first.
|
||||
"localhost",
|
||||
"0.0.0.0",
|
||||
"127.0.0.1",
|
||||
# Try to avoid making unnecessary HTTP requests by checking if the user
|
||||
# manually specified a server address.
|
||||
_get_server_address_if_manually_set,
|
||||
# Then try the options that depend on HTTP requests or opening sockets.
|
||||
net_util.get_internal_ip,
|
||||
net_util.get_external_ip,
|
||||
]
|
||||
|
||||
for allowed_domain in allowed_domains:
|
||||
if callable(allowed_domain):
|
||||
allowed_domain = allowed_domain()
|
||||
|
||||
if allowed_domain is None:
|
||||
continue
|
||||
|
||||
if hostname == allowed_domain:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_cookie_secret() -> str:
|
||||
"""Get the cookie secret.
|
||||
|
||||
If the user has not set a cookie secret, we generate a random one.
|
||||
"""
|
||||
cookie_secret: str = config.get_option("server.cookieSecret")
|
||||
if secrets_singleton.load_if_toml_exists():
|
||||
auth_section = secrets_singleton.get("auth")
|
||||
if auth_section:
|
||||
cookie_secret = auth_section.get("cookie_secret", cookie_secret)
|
||||
return cookie_secret
|
||||
|
||||
|
||||
def is_xsrf_enabled():
|
||||
csrf_enabled = config.get_option("server.enableXsrfProtection")
|
||||
if not csrf_enabled and secrets_singleton.load_if_toml_exists():
|
||||
auth_section = secrets_singleton.get("auth", None)
|
||||
csrf_enabled = csrf_enabled or auth_section is not None
|
||||
return csrf_enabled
|
||||
|
||||
|
||||
def _get_server_address_if_manually_set() -> str | None:
|
||||
if config.is_manually_set("browser.serverAddress"):
|
||||
return url_util.get_hostname(config.get_option("browser.serverAddress"))
|
||||
return None
|
||||
|
||||
|
||||
def make_url_path_regex(
|
||||
*path, trailing_slash: Literal["optional", "required", "prohibited"] = "optional"
|
||||
) -> str:
|
||||
"""Get a regex of the form ^/foo/bar/baz/?$ for a path (foo, bar, baz)."""
|
||||
path = [x.strip("/") for x in path if x] # Filter out falsely components.
|
||||
path_format = r"^/%s$"
|
||||
if trailing_slash == "optional":
|
||||
path_format = r"^/%s/?$"
|
||||
elif trailing_slash == "required":
|
||||
path_format = r"^/%s/$"
|
||||
|
||||
return path_format % "/".join(path)
|
||||
|
||||
|
||||
def get_url(host_ip: str) -> str:
|
||||
"""Get the URL for any app served at the given host_ip.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host_ip : str
|
||||
The IP address of the machine that is running the Streamlit Server.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The URL.
|
||||
"""
|
||||
protocol = "https" if config.get_option("server.sslCertFile") else "http"
|
||||
|
||||
port = _get_browser_address_bar_port()
|
||||
base_path = config.get_option("server.baseUrlPath").strip("/")
|
||||
|
||||
if base_path:
|
||||
base_path = "/" + base_path
|
||||
|
||||
host_ip = host_ip.strip("/")
|
||||
return f"{protocol}://{host_ip}:{port}{base_path}"
|
||||
|
||||
|
||||
def _get_browser_address_bar_port() -> int:
|
||||
"""Get the app URL that will be shown in the browser's address bar.
|
||||
|
||||
That is, this is the port where static assets will be served from. In dev,
|
||||
this is different from the URL that will be used to connect to the
|
||||
server-browser websocket.
|
||||
|
||||
"""
|
||||
if config.get_option("global.developmentMode"):
|
||||
return DEVELOPMENT_PORT
|
||||
return int(config.get_option("browser.serverPort"))
|
||||
|
||||
|
||||
def emit_endpoint_deprecation_notice(handler: RequestHandler, new_path: str) -> None:
|
||||
"""Emits the warning about deprecation of HTTP endpoint in the HTTP header."""
|
||||
handler.set_header("Deprecation", True)
|
||||
new_url = urljoin(f"{handler.request.protocol}://{handler.request.host}", new_path)
|
||||
handler.set_header("Link", f'<{new_url}>; rel="alternate"')
|
||||
@@ -0,0 +1,95 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import tornado.web
|
||||
|
||||
from streamlit.web.server import allow_cross_origin_requests
|
||||
from streamlit.web.server.server_util import emit_endpoint_deprecation_notice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.proto.openmetrics_data_model_pb2 import MetricSet as MetricSetProto
|
||||
from streamlit.runtime.stats import CacheStat, StatsManager
|
||||
|
||||
|
||||
class StatsRequestHandler(tornado.web.RequestHandler):
|
||||
def initialize(self, stats_manager: StatsManager) -> None:
|
||||
self._manager = stats_manager
|
||||
|
||||
def set_default_headers(self):
|
||||
if allow_cross_origin_requests():
|
||||
self.set_header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
def options(self):
|
||||
"""/OPTIONS handler for preflight CORS checks."""
|
||||
self.set_status(204)
|
||||
self.finish()
|
||||
|
||||
def get(self) -> None:
|
||||
if self.request.uri and "_stcore/" not in self.request.uri:
|
||||
emit_endpoint_deprecation_notice(self, new_path="/_stcore/metrics")
|
||||
|
||||
stats = self._manager.get_stats()
|
||||
|
||||
# If the request asked for protobuf output, we return a serialized
|
||||
# protobuf. Else we return text.
|
||||
if "application/x-protobuf" in self.request.headers.get_list("Accept"):
|
||||
self.write(self._stats_to_proto(stats).SerializeToString())
|
||||
self.set_header("Content-Type", "application/x-protobuf")
|
||||
self.set_status(200)
|
||||
else:
|
||||
self.write(self._stats_to_text(self._manager.get_stats()))
|
||||
self.set_header("Content-Type", "application/openmetrics-text")
|
||||
self.set_status(200)
|
||||
|
||||
@staticmethod
|
||||
def _stats_to_text(stats: list[CacheStat]) -> str:
|
||||
metric_type = "# TYPE cache_memory_bytes gauge"
|
||||
metric_unit = "# UNIT cache_memory_bytes bytes"
|
||||
metric_help = "# HELP Total memory consumed by a cache."
|
||||
openmetrics_eof = "# EOF\n"
|
||||
|
||||
# Format: header, stats, EOF
|
||||
result = [metric_type, metric_unit, metric_help]
|
||||
result.extend(stat.to_metric_str() for stat in stats)
|
||||
result.append(openmetrics_eof)
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
@staticmethod
|
||||
def _stats_to_proto(stats: list[CacheStat]) -> MetricSetProto:
|
||||
# Lazy load the import of this proto message for better performance:
|
||||
from streamlit.proto.openmetrics_data_model_pb2 import GAUGE
|
||||
from streamlit.proto.openmetrics_data_model_pb2 import (
|
||||
MetricSet as MetricSetProto,
|
||||
)
|
||||
|
||||
metric_set = MetricSetProto()
|
||||
|
||||
metric_family = metric_set.metric_families.add()
|
||||
metric_family.name = "cache_memory_bytes"
|
||||
metric_family.type = GAUGE
|
||||
metric_family.unit = "bytes"
|
||||
metric_family.help = "Total memory consumed by a cache."
|
||||
|
||||
for stat in stats:
|
||||
metric_proto = metric_family.metrics.add()
|
||||
stat.marshall_metric_proto(metric_proto)
|
||||
|
||||
metric_set = MetricSetProto()
|
||||
metric_set.metric_families.append(metric_family)
|
||||
return metric_set
|
||||
@@ -0,0 +1,137 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
import tornado.httputil
|
||||
import tornado.web
|
||||
|
||||
from streamlit import config
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileRec
|
||||
from streamlit.web.server import routes, server_util
|
||||
from streamlit.web.server.server_util import is_xsrf_enabled
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
|
||||
|
||||
|
||||
class UploadFileRequestHandler(tornado.web.RequestHandler):
|
||||
"""Implements the POST /upload_file endpoint."""
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
file_mgr: MemoryUploadedFileManager,
|
||||
is_active_session: Callable[[str], bool],
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
file_mgr : UploadedFileManager
|
||||
The server's singleton UploadedFileManager. All file uploads
|
||||
go here.
|
||||
is_active_session:
|
||||
A function that returns true if a session_id belongs to an active
|
||||
session.
|
||||
"""
|
||||
self._file_mgr = file_mgr
|
||||
self._is_active_session = is_active_session
|
||||
|
||||
def set_default_headers(self):
|
||||
self.set_header("Access-Control-Allow-Methods", "PUT, OPTIONS, DELETE")
|
||||
self.set_header("Access-Control-Allow-Headers", "Content-Type")
|
||||
if is_xsrf_enabled():
|
||||
self.set_header(
|
||||
"Access-Control-Allow-Origin",
|
||||
server_util.get_url(config.get_option("browser.serverAddress")),
|
||||
)
|
||||
self.set_header("Access-Control-Allow-Headers", "X-Xsrftoken, Content-Type")
|
||||
self.set_header("Vary", "Origin")
|
||||
self.set_header("Access-Control-Allow-Credentials", "true")
|
||||
elif routes.allow_cross_origin_requests():
|
||||
self.set_header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
def options(self, **kwargs):
|
||||
"""/OPTIONS handler for preflight CORS checks.
|
||||
|
||||
When a browser is making a CORS request, it may sometimes first
|
||||
send an OPTIONS request, to check whether the server understands the
|
||||
CORS protocol. This is optional, and doesn't happen for every request
|
||||
or in every browser. If an OPTIONS request does get sent, and is not
|
||||
then handled by the server, the browser will fail the underlying
|
||||
request.
|
||||
|
||||
The proper way to handle this is to send a 204 response ("no content")
|
||||
with the CORS headers attached. (These headers are automatically added
|
||||
to every outgoing response, including OPTIONS responses,
|
||||
via set_default_headers().)
|
||||
|
||||
See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
|
||||
"""
|
||||
self.set_status(204)
|
||||
self.finish()
|
||||
|
||||
def put(self, **kwargs):
|
||||
"""Receive an uploaded file and add it to our UploadedFileManager."""
|
||||
|
||||
args: dict[str, list[bytes]] = {}
|
||||
files: dict[str, list[Any]] = {}
|
||||
|
||||
session_id = self.path_kwargs["session_id"]
|
||||
file_id = self.path_kwargs["file_id"]
|
||||
|
||||
tornado.httputil.parse_body_arguments(
|
||||
content_type=self.request.headers["Content-Type"],
|
||||
body=self.request.body,
|
||||
arguments=args,
|
||||
files=files,
|
||||
)
|
||||
|
||||
try:
|
||||
if not self._is_active_session(session_id):
|
||||
raise Exception("Invalid session_id")
|
||||
except Exception as e:
|
||||
self.send_error(400, reason=str(e))
|
||||
return
|
||||
|
||||
uploaded_files: list[UploadedFileRec] = []
|
||||
|
||||
for _, flist in files.items():
|
||||
for file in flist:
|
||||
uploaded_files.append(
|
||||
UploadedFileRec(
|
||||
file_id=file_id,
|
||||
name=file["filename"],
|
||||
type=file["content_type"],
|
||||
data=file["body"],
|
||||
)
|
||||
)
|
||||
|
||||
if len(uploaded_files) != 1:
|
||||
self.send_error(
|
||||
400, reason=f"Expected 1 file, but got {len(uploaded_files)}"
|
||||
)
|
||||
return
|
||||
|
||||
self._file_mgr.add_file(session_id=session_id, file=uploaded_files[0])
|
||||
self.set_status(204)
|
||||
|
||||
def delete(self, **kwargs):
|
||||
"""Delete file request handler."""
|
||||
session_id = self.path_kwargs["session_id"]
|
||||
file_id = self.path_kwargs["file_id"]
|
||||
|
||||
self._file_mgr.remove_file(session_id=session_id, file_id=file_id)
|
||||
self.set_status(204)
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from streamlit import runtime
|
||||
from streamlit.deprecation_util import show_deprecation_warning
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
|
||||
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler
|
||||
|
||||
_GET_WEBSOCKET_HEADERS_DEPRECATE_MSG = (
|
||||
"The `_get_websocket_headers` function is deprecated and will be removed "
|
||||
"in a future version of Streamlit. Please use `st.context.headers` instead."
|
||||
)
|
||||
|
||||
|
||||
@gather_metrics("_get_websocket_headers")
|
||||
def _get_websocket_headers() -> dict[str, str] | None:
|
||||
"""Return a copy of the HTTP request headers for the current session's
|
||||
WebSocket connection. If there's no active session, return None instead.
|
||||
|
||||
Raise an error if the server is not running.
|
||||
|
||||
Note to the intrepid: this is an UNSUPPORTED, INTERNAL API. (We don't have plans
|
||||
to remove it without a replacement, but we don't consider this a production-ready
|
||||
function, and its signature may change without a deprecation warning.)
|
||||
"""
|
||||
|
||||
show_deprecation_warning(_GET_WEBSOCKET_HEADERS_DEPRECATE_MSG)
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return None
|
||||
|
||||
session_client = runtime.get_instance().get_client(ctx.session_id)
|
||||
if session_client is None:
|
||||
return None
|
||||
|
||||
if not isinstance(session_client, BrowserWebSocketHandler):
|
||||
raise RuntimeError(
|
||||
f"SessionClient is not a BrowserWebSocketHandler! ({session_client})"
|
||||
)
|
||||
|
||||
return dict(session_client.request.headers)
|
||||
Reference in New Issue
Block a user