Mise à jour de Monitor.py et autres scripts

This commit is contained in:
Debian
2025-07-23 10:46:27 +02:00
parent 7081418ce0
commit 7de3e0fb50
8604 changed files with 2789953 additions and 295 deletions

View File

@@ -0,0 +1,13 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,355 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import os
import signal
import sys
from typing import Any, Final
from streamlit import cli_util, config, env_util, file_util, net_util, secrets
from streamlit.config import CONFIG_FILENAMES
from streamlit.git_util import MIN_GIT_VERSION, GitRepo
from streamlit.logger import get_logger
from streamlit.watcher import report_watchdog_availability, watch_file
from streamlit.web.server import Server, server_address_is_unix_socket, server_util
_LOGGER: Final = get_logger(__name__)
# The maximum possible total size of a static directory.
# We agreed on these limitations for the initial release of static file sharing,
# based on security concerns from the SiS and Community Cloud teams
MAX_APP_STATIC_FOLDER_SIZE = 1 * 1024 * 1024 * 1024 # 1 GB
def _set_up_signal_handler(server: Server) -> None:
    """Install OS signal handlers that stop ``server``.

    SIGTERM and SIGINT are always handled. On Windows
    (``sys.platform == "win32"``) SIGBREAK is handled as well; on every other
    platform SIGQUIT is.
    """
    _LOGGER.debug("Setting up signal handler")

    def signal_handler(signal_number, stack_frame):
        # The server will shut down its threads and exit its loop.
        server.stop()

    # Only the attribute for the current platform is looked up, so the missing
    # one (SIGQUIT on Windows, SIGBREAK elsewhere) is never touched.
    platform_signal = signal.SIGBREAK if sys.platform == "win32" else signal.SIGQUIT
    for signum in (signal.SIGTERM, signal.SIGINT, platform_signal):
        signal.signal(signum, signal_handler)
def _fix_sys_path(main_script_path: str) -> None:
    """Put the main script's directory at the front of ``sys.path``.

    Python normally does this on its own, but because we exec the script
    ourselves the entry has to be added by hand.
    """
    script_dir = os.path.dirname(main_script_path)
    sys.path.insert(0, script_dir)
def _fix_tornado_crash() -> None:
    """Switch Windows to the selector event loop policy for Tornado 6.

    Tornado 6 (at least) is not compatible with the default asyncio
    implementation on Windows. When the known-incompatible
    WindowsProactorEventLoopPolicy is the active policy, we fall back to the
    older WindowsSelectorEventLoopPolicy. Nothing happens on other platforms
    or when the policy classes cannot be imported.

    This has to happen as early as possible to make it a low priority and
    overridable.

    See: https://github.com/tornadoweb/tornado/issues/2608

    FIXME: if/when tornado supports the defaults in asyncio,
    remove and bump tornado requirement for py38
    """
    if not env_util.IS_WINDOWS:
        return

    try:
        from asyncio import (  # type: ignore[attr-defined]
            WindowsProactorEventLoopPolicy,
            WindowsSelectorEventLoopPolicy,
        )
    except ImportError:
        # Not affected
        return

    active_policy = asyncio.get_event_loop_policy()
    # Exact type match on purpose: only the known-incompatible default policy
    # is replaced, never a subclass or a policy someone else installed.
    if type(active_policy) is WindowsProactorEventLoopPolicy:
        # WindowsProactorEventLoopPolicy is not compatible with
        # Tornado 6 fallback to the pre-3.8 default of Selector
        asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
def _fix_sys_argv(main_script_path: str, args: list[str]) -> None:
    """Replace ``sys.argv`` with what the user's script expects to see.

    Streamlit's own arguments and parameters are dropped: the new argv is the
    script path followed by the script's arguments.
    """
    sys.argv = [main_script_path, *args]
def _on_server_start(server: Server) -> None:
    """Run the startup tasks that follow a successful server start.

    Prints the Git and static-folder warnings and the app URLs, reports
    watchdog availability, loads secrets.toml if it exists, and schedules the
    browser to open on the running event loop.
    """
    _maybe_print_old_git_warning(server.main_script_path)
    _maybe_print_static_folder_warning(server.main_script_path)
    _print_url(server.is_running_hello)
    report_watchdog_availability()

    # Load secrets.toml if it exists. A missing file is not an error; parse
    # errors (or anything else that goes wrong) are logged here instead of
    # being propagated.
    try:
        secrets.load_if_toml_exists()
    except Exception as ex:
        _LOGGER.error("Failed to load secrets.toml file", exc_info=ex)

    def maybe_open_browser():
        # Don't open browser when in headless mode.
        if config.get_option("server.headless"):
            return

        if config.is_manually_set("browser.serverAddress"):
            addr = config.get_option("browser.serverAddress")
        elif not config.is_manually_set("server.address"):
            addr = "localhost"
        elif server_address_is_unix_socket():
            # Don't open browser when server address is a unix socket
            return
        else:
            addr = config.get_option("server.address")

        cli_util.open_browser(server_util.get_url(addr))

    # Schedule the browser to open on the main thread.
    asyncio.get_running_loop().call_soon(maybe_open_browser)
def _fix_pydeck_mapbox_api_warning() -> None:
    """Export the ``mapbox.token`` config option as ``MAPBOX_API_KEY``.

    PyDeck needs this environment variable; without it, it will throw an
    exception.
    """
    mapbox_token = config.get_option("mapbox.token")
    os.environ["MAPBOX_API_KEY"] = mapbox_token
def _maybe_print_static_folder_warning(main_script_path: str) -> None:
    """Warn on the CLI when the app's static folder is misconfigured.

    Does nothing unless ``server.enableStaticServing`` is on. Otherwise it
    warns when the static folder does not exist, and when the folder is larger
    than MAX_APP_STATIC_FOLDER_SIZE it turns static serving off and warns
    about that.
    """
    if not config.get_option("server.enableStaticServing"):
        return

    static_folder_path = file_util.get_app_static_dir(main_script_path)
    if not os.path.isdir(static_folder_path):
        cli_util.print_to_cli(
            f"WARNING: Static file serving is enabled, but no static folder found "
            f"at {static_folder_path}. To disable static file serving, "
            f"set server.enableStaticServing to false.",
            fg="yellow",
        )
        return

    # Raise warning when static folder size is larger than 1 GB
    if file_util.get_directory_size(static_folder_path) > MAX_APP_STATIC_FOLDER_SIZE:
        config.set_option("server.enableStaticServing", False)
        cli_util.print_to_cli(
            "WARNING: Static folder size is larger than 1GB. "
            "Static file serving has been disabled.",
            fg="yellow",
        )
def _print_url(is_running_hello: bool) -> None:
    """Print the startup banner with the URL(s) the app can be reached at.

    Which URLs are listed depends on the configuration:

    * ``browser.serverAddress`` set manually: that address only.
    * server address is a unix socket: the socket path.
    * ``server.address`` set manually: that address only.
    * otherwise: localhost, the internal network IP if one is found and, in
      headless mode, the external IP if one is found.

    When ``is_running_hello`` is true, a different title and a short pointer
    to the docs are printed as well.
    """
    title_message = (
        "Welcome to Streamlit. Check out our demo in your browser."
        if is_running_hello
        else "You can now view your Streamlit app in your browser."
    )

    if config.is_manually_set("browser.serverAddress"):
        browser_address = config.get_option("browser.serverAddress")
        named_urls = [("URL", server_util.get_url(browser_address))]
    elif server_address_is_unix_socket():
        named_urls = [("Unix Socket", config.get_option("server.address"))]
    elif config.is_manually_set("server.address"):
        server_address = config.get_option("server.address")
        named_urls = [("URL", server_util.get_url(server_address))]
    else:
        named_urls = [("Local URL", server_util.get_url("localhost"))]

        internal_ip = net_util.get_internal_ip()
        if internal_ip:
            named_urls.append(("Network URL", server_util.get_url(internal_ip)))

        if config.get_option("server.headless"):
            external_ip = net_util.get_external_ip()
            if external_ip:
                named_urls.append(("External URL", server_util.get_url(external_ip)))

    cli_util.print_to_cli("")
    cli_util.print_to_cli(f" {title_message}", fg="blue", bold=True)
    cli_util.print_to_cli("")

    for label, address in named_urls:
        cli_util.print_to_cli(f" {label}: ", nl=False, fg="blue")
        cli_util.print_to_cli(address, bold=True)
    cli_util.print_to_cli("")

    if not is_running_hello:
        return

    cli_util.print_to_cli(" Ready to create your own Python apps super quickly?")
    cli_util.print_to_cli(" Head over to ", nl=False)
    cli_util.print_to_cli("https://docs.streamlit.io", bold=True)
    cli_util.print_to_cli("")
    cli_util.print_to_cli(" May you create awesome apps!")
    cli_util.print_to_cli("")
    cli_util.print_to_cli("")
def _maybe_print_old_git_warning(main_script_path: str) -> None:
    """Warn when Git integration is unavailable because Git is too old.

    The warning is printed only when the repo for the script is not valid, a
    Git version could be determined, and that version is older than
    MIN_GIT_VERSION.
    """
    repo = GitRepo(main_script_path)
    if (
        repo.is_valid()
        or repo.git_version is None
        or repo.git_version >= MIN_GIT_VERSION
    ):
        return

    installed = ".".join(str(part) for part in repo.git_version)
    required = ".".join(str(part) for part in MIN_GIT_VERSION)

    cli_util.print_to_cli("")
    cli_util.print_to_cli(" Git integration is disabled.", fg="yellow", bold=True)
    cli_util.print_to_cli("")

    detail_lines = (
        f" Streamlit requires Git {required} or later, " f"but you have {installed}.",
        " Git is used by Streamlit Cloud (https://streamlit.io/cloud).",
        " To enable this feature, please update Git.",
    )
    for line in detail_lines:
        cli_util.print_to_cli(line, fg="yellow")
def load_config_options(flag_options: dict[str, Any]) -> None:
    """Reload config from the config.toml files and overlay CLI flag values.

    The "streamlit run" command supports passing Streamlit's config options
    as flags. This function takes the options set via flag, converts their
    names, and hands them to get_config_options() so that they overwrite
    config option defaults and those loaded from config.toml files.

    Parameters
    ----------
    flag_options : dict[str, Any]
        A dict of config options where the keys are the CLI flag version of the
        config option names.
    """
    # Two kinds of values are dropped: None, and empty tuples. The latter is a
    # special case meaning that no values were provided, so the config should
    # fall back to the default.
    options_from_flags: dict[str, Any] = {}
    for flag_name, flag_value in flag_options.items():
        if flag_value is not None and flag_value != ():
            options_from_flags[flag_name.replace("_", ".")] = flag_value

    # Force a reparse of config files (if they exist). The result is cached
    # for future calls.
    config.get_config_options(force_reparse=True, options_from_flags=options_from_flags)
def _install_config_watchers(flag_options: dict[str, Any]) -> None:
    """Reload the config whenever an existing config file changes.

    Each path in CONFIG_FILENAMES that exists when this is called gets a file
    watcher whose callback re-runs load_config_options with ``flag_options``.
    """

    def on_config_changed(_path):
        load_config_options(flag_options)

    existing_files = (path for path in CONFIG_FILENAMES if os.path.exists(path))
    for config_file in existing_files:
        watch_file(config_file, on_config_changed)
# Strong references to server tasks that were scheduled on an already-running
# event loop. The event loop only keeps weak references to tasks, so a task
# that nothing else refers to may be garbage-collected before it finishes.
_BACKGROUND_SERVER_TASKS: set[asyncio.Task[None]] = set()


def run(
    main_script_path: str,
    is_hello: bool,
    args: list[str],
    flag_options: dict[str, Any],
    *,
    stop_immediately_for_testing: bool = False,
) -> None:
    """Run a script in a separate thread and start a server for the app.

    When no event loop is running in the calling thread, this starts a
    blocking asyncio event loop and only returns once the server has shut
    down. When it is called from inside a running event loop (where
    ``asyncio.run`` would raise ``RuntimeError``), the server is scheduled as
    a task on that loop and this function returns immediately.

    Parameters
    ----------
    main_script_path : str
        Path of the script to serve; its folder is added to ``sys.path``.
    is_hello : bool
        Passed through to the Server constructor.
    args : list[str]
        Arguments for the script; they become ``sys.argv[1:]``.
    flag_options : dict[str, Any]
        Config options given as CLI flags; re-applied when a config file
        changes.
    stop_immediately_for_testing : bool
        If true, the server is stopped right after it has started.
    """
    _fix_sys_path(main_script_path)
    _fix_tornado_crash()
    _fix_sys_argv(main_script_path, args)
    _fix_pydeck_mapbox_api_warning()
    _install_config_watchers(flag_options)

    # Create the server. It won't start running yet.
    server = Server(main_script_path, is_hello)

    async def run_server() -> None:
        # Start the server
        await server.start()
        _on_server_start(server)

        # Install a signal handler that will shut down the server
        # and close all our threads
        _set_up_signal_handler(server)

        # return immediately if we're testing the server start
        if stop_immediately_for_testing:
            _LOGGER.debug("Stopping server immediately for testing")
            server.stop()

        # Wait until `Server.stop` is called, either by our signal handler, or
        # by a debug websocket session.
        await server.stopped

    # asyncio.run() cannot be called from a running event loop, so find out
    # first whether there is one. Only this lookup sits inside the `try`, so a
    # RuntimeError raised by the server itself is never mistaken for "no
    # running loop".
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is None:
        # Run the server. This call will not return until the server is shut
        # down.
        asyncio.run(run_server())
        return

    # Already inside an event loop: schedule the server on it and keep a
    # strong reference to the task until it is done.
    server_task = running_loop.create_task(run_server())
    _BACKGROUND_SERVER_TASKS.add(server_task)
    server_task.add_done_callback(_BACKGROUND_SERVER_TASKS.discard)

View File

@@ -0,0 +1,38 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from streamlit.runtime.caching.storage.local_disk_cache_storage import (
LocalDiskCacheStorageManager,
)
if TYPE_CHECKING:
from streamlit.runtime.caching.storage import CacheStorageManager
def create_default_cache_storage_manager() -> CacheStorageManager:
    """
    Build the default cache storage manager.

    Meant to be used both in server.py and in cli.py so that both share the
    same kind of cache storage.

    Returns
    -------
    CacheStorageManager
        A new LocalDiskCacheStorageManager instance.
    """
    storage_manager = LocalDiskCacheStorageManager()
    return storage_manager

View File

@@ -0,0 +1,411 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script which is run when the Streamlit package is executed."""
from __future__ import annotations
import os
import sys
from typing import TYPE_CHECKING, Any, Callable, TypeVar
# We cannot lazy-load click here because it's used via decorators.
import click
import streamlit.runtime.caching as caching
import streamlit.web.bootstrap as bootstrap
from streamlit import config as _config
from streamlit.runtime.credentials import Credentials, check_credentials
from streamlit.web.cache_storage_manager_config import (
create_default_cache_storage_manager,
)
if TYPE_CHECKING:
from streamlit.config_option import ConfigOption
# File extensions (without the leading dot) that `streamlit run` accepts.
ACCEPTED_FILE_EXTENSIONS = ("py", "py3")
# Choices offered for the `--log_level` option of the top-level command.
LOG_LEVELS = ("error", "warning", "info", "debug")
def _convert_config_option_to_click_option(
    config_option: ConfigOption,
) -> dict[str, Any]:
    """Translate a config option into the pieces needed for a click option.

    Returns a dict with the click parameter name (``param``, the key with dots
    replaced by underscores), the flag (``option``, ``--<key>``), and the
    option's ``description``, ``type``, ``envvar`` and ``multiple`` setting.
    For deprecated options the deprecation text and expiration date are
    appended to the description.
    """
    key = config_option.key
    help_text = config_option.description

    if config_option.deprecated:
        deprecation_note = (
            f"\n {config_option.deprecation_text} - {config_option.expiration_date}"
        )
        help_text = ("" if help_text is None else help_text) + deprecation_note

    return {
        "param": key.replace(".", "_"),
        "description": help_text,
        "type": config_option.type,
        "option": f"--{key}",
        "envvar": config_option.env_var,
        "multiple": config_option.multiple,
    }
def _make_sensitive_option_callback(config_option: ConfigOption):
    """Build a click callback that rejects a sensitive option set via CLI flag.

    The returned callback does nothing when no value was supplied. Otherwise
    it exits with a message that points the user to the configuration file or
    the option's environment variable.
    """

    def callback(_ctx: click.Context, _param: click.Parameter, cli_value) -> None:
        if cli_value is not None:
            raise SystemExit(
                f"Setting {config_option.key!r} option using the CLI flag is not allowed. "
                f"Set this option in the configuration file or environment "
                f"variable: {config_option.env_var!r}"
            )

    return callback
# Type variable that lets a decorator return the same callable type it receives.
F = TypeVar("F", bound=Callable[..., Any])
def configurator_options(func: F) -> F:
    """Decorator that adds a click option for every Streamlit config option.

    The options template is walked in reverse and each option is wrapped
    around ``func`` in turn. Sensitive options are registered as hidden, eager
    options whose callback exits when they are set via the CLI; all other
    options show their environment variable in the help output.
    """
    for template_option in reversed(_config._config_options_template.values()):
        click_params = _convert_config_option_to_click_option(template_option)

        if template_option.sensitive:
            # Display a warning if the user tries to set sensitive
            # options using the CLI and exit with non-zero code.
            extra_kwargs = {
                "expose_value": False,
                "hidden": True,
                "is_eager": True,
                "callback": _make_sensitive_option_callback(template_option),
            }
        else:
            extra_kwargs = {
                "show_envvar": True,
                "envvar": click_params["envvar"],
            }

        add_option = click.option(
            click_params["option"],
            click_params["param"],
            help=click_params["description"],
            type=click_params["type"],
            multiple=click_params["multiple"],
            **extra_kwargs,
        )
        func = add_option(func)

    return func
def _download_remote(main_script_path: str, url_path: str) -> None:
    """Fetch remote file at url_path to main_script_path.

    The download is completed before the destination file is opened, so a
    failed request does not leave an empty file behind.

    Parameters
    ----------
    main_script_path : str
        Local path the downloaded content is written to (binary mode).
    url_path : str
        URL to download.

    Raises
    ------
    click.BadParameter
        If the request fails or the server answers with an HTTP error status.
        The original requests exception is attached as the cause.
    """
    import requests

    # Keep the try body limited to the calls that can raise a request error.
    try:
        resp = requests.get(url_path)
        resp.raise_for_status()
    except requests.exceptions.RequestException as e:
        raise click.BadParameter(f"Unable to fetch {url_path}.\n{e}") from e

    with open(main_script_path, "wb") as fp:
        fp.write(resp.content)
@click.group(context_settings={"auto_envvar_prefix": "STREAMLIT"})
@click.option("--log_level", show_default=True, type=click.Choice(LOG_LEVELS))
@click.version_option(prog_name="Streamlit")
def main(log_level="info"):
    """Try out a demo with:
    $ streamlit hello
    Or use the line below to run your own script:
    $ streamlit run your_script.py
    """  # noqa: D400
    # The docstring above is the help text click shows for the top-level
    # command. The --log_level flag is only accepted so that a pointer to its
    # replacement can be logged.
    if not log_level:
        return

    from streamlit.logger import get_logger

    logger = get_logger(__name__)
    logger.warning(
        "Setting the log level using the --log_level flag is unsupported."
        "\nUse the --logger.level flag (after your streamlit command) instead."
    )
@main.command("help")
def help():
    """Print this help message."""
    # We use _get_command_line_as_string to run some error checks but don't do
    # anything with its return value.
    _get_command_line_as_string()
    assert len(sys.argv) == 2  # This is always true, but let's assert anyway.
    # Pretend user typed 'streamlit --help' instead of 'streamlit help', then
    # re-enter the top-level command so click prints its help output.
    sys.argv[1] = "--help"
    main(prog_name="streamlit")
@main.command("version")
def main_version():
    """Print Streamlit's version number."""
    # Only called for its error checks; the return value is not needed.
    _get_command_line_as_string()

    assert len(sys.argv) == 2  # This is always true, but let's assert anyway.

    # Pretend user typed 'streamlit --version' instead of 'streamlit version'
    # and re-enter the top-level command.
    sys.argv[1] = "--version"
    main()
@main.command("docs")
def main_docs():
    """Show help in browser."""
    click.echo("Showing help page in browser...")

    # Imported lazily, after the message has been printed.
    from streamlit import cli_util

    docs_url = "https://docs.streamlit.io"
    cli_util.open_browser(docs_url)
@main.command("hello")
@configurator_options
def main_hello(**kwargs):
    """Runs the Hello World script."""
    from streamlit.hello import streamlit_app

    # kwargs holds the config options passed as CLI flags.
    bootstrap.load_config_options(flag_options=kwargs)
    _main_run(streamlit_app.__file__, flag_options=kwargs)
@main.command("run")
@configurator_options
@click.argument("target", required=True, envvar="STREAMLIT_RUN_TARGET")
@click.argument("args", nargs=-1)
def main_run(target: str, args=None, **kwargs):
    """Run a Python script, piping stderr to Streamlit.
    The script can be local or it can be an url. In the latter case, Streamlit
    will download the script to a temporary file and runs this file.
    """
    from streamlit import url_util

    bootstrap.load_config_options(flag_options=kwargs)

    # Reject anything that does not end in an accepted extension, with a
    # dedicated message for targets that have no extension at all.
    extension = os.path.splitext(target)[1]
    bare_extension = extension[1:]
    if bare_extension not in ACCEPTED_FILE_EXTENSIONS:
        if bare_extension == "":
            raise click.BadArgumentUsage(
                "Streamlit requires raw Python (.py) files, but the provided file has no extension.\nFor more information, please see https://docs.streamlit.io"
            )
        raise click.BadArgumentUsage(
            f"Streamlit requires raw Python (.py) files, not {extension}.\nFor more information, please see https://docs.streamlit.io"
        )

    # Local file: it has to exist, then it is run directly.
    if not url_util.is_url(target):
        if not os.path.exists(target):
            raise click.BadParameter(f"File does not exist: {target}")
        _main_run(target, args, flag_options=kwargs)
        return

    # Remote script: download it into a temporary directory and run it from
    # there. The app runs inside the `with` block, so the directory is still
    # in use for the whole run.
    from streamlit.temporary_directory import TemporaryDirectory

    with TemporaryDirectory() as temp_dir:
        from urllib.parse import urlparse

        url_path = urlparse(target).path
        script_name = url_path.strip("/").rsplit("/", 1)[-1]
        main_script_path = os.path.join(temp_dir, script_name)

        # if this is a GitHub/Gist blob url, convert to a raw URL first.
        target = url_util.process_gitblob_url(target)
        _download_remote(main_script_path, target)
        _main_run(main_script_path, args, flag_options=kwargs)
def _get_command_line_as_string() -> str | None:
    """Return the current command line as a single string.

    The string is built from the parent click context's command path followed
    by ``sys.argv[1:]``. Returns None when the current click context has no
    parent.

    Raises
    ------
    RuntimeError
        If the parent command path contains "streamlit.cli", i.e. Streamlit
        was started via ``python -m streamlit.cli``.
    """
    import subprocess

    parent_ctx = click.get_current_context().parent
    if parent_ctx is None:
        return None

    if "streamlit.cli" in parent_ctx.command_path:
        raise RuntimeError(
            "Running streamlit via `python -m streamlit.cli <command>` is"
            " unsupported. Please use `python -m streamlit <command>` instead."
        )

    return subprocess.list2cmdline([parent_ctx.command_path, *sys.argv[1:]])
def _main_run(
    file,
    args: list[str] | None = None,
    flag_options: dict[str, Any] | None = None,
) -> None:
    """Check credentials and hand the script over to ``bootstrap.run``.

    Missing ``args`` and ``flag_options`` default to an empty list and an
    empty dict. The run counts as the hello app when the command line is
    exactly "streamlit hello".
    """
    script_args = [] if args is None else args
    options = {} if flag_options is None else flag_options

    is_hello = _get_command_line_as_string() == "streamlit hello"

    check_credentials()

    bootstrap.run(file, is_hello, script_args, options)
# SUBCOMMAND: cache
@main.group("cache")
def cache():
    """Manage the Streamlit cache."""
    # Group container only; the work is done by its subcommands.
@cache.command("clear")
def cache_clear():
    """Clear st.cache_data and st.cache_resource caches."""
    # The runtime is not initialized in this `streamlit cache clear` command,
    # so its cache storage manager cannot be used. A fresh instance of the
    # same kind is created instead and asked to clear everything.
    # This will not remove the in-memory cache.
    create_default_cache_storage_manager().clear_all()
    caching.cache_resource.clear()
# SUBCOMMAND: config
@main.group("config")
def config():
    """Manage Streamlit's config settings."""
    # Group container only; the work is done by its subcommands.
@config.command("show")
@configurator_options
def config_show(**kwargs):
    """Show all of Streamlit's config settings."""
    # Apply the config options given as CLI flags before printing, so the
    # output reflects them.
    bootstrap.load_config_options(flag_options=kwargs)
    _config.show_config()
# SUBCOMMAND: activate
@main.group("activate", invoke_without_command=True)
@click.pass_context
def activate(ctx):
    """Activate Streamlit by entering your email."""
    # The group also runs when a subcommand was given; activation only happens
    # for the bare `activate` command.
    if ctx.invoked_subcommand:
        return
    Credentials.get_current().activate()
@activate.command("reset")
def activate_reset():
    """Reset Activation Credentials."""
    # Delegates to the current Credentials object.
    Credentials.get_current().reset()
# SUBCOMMAND: test
@main.group("test", hidden=True)
def test():
    """Internal-only commands used for testing.
    These commands are not included in the output of `streamlit help`.
    """
    # Group container only; the work is done by its subcommands.
@test.command("prog_name")
def test_prog_name():
    """Assert that the program name is set to `streamlit test`.
    This is used by our cli-smoke-tests to verify that the program name is set
    to `streamlit ...` whether the streamlit binary is invoked directly or via
    `python -m streamlit ...`.
    """
    # Only called for its error checks; the return value is not needed.
    _get_command_line_as_string()

    parent_ctx = click.get_current_context().parent

    assert parent_ctx is not None
    assert parent_ctx.command_path == "streamlit test"
@main.command("init")
@click.argument("directory", required=False)
def main_init(directory: str | None = None):
    """Initialize a new Streamlit project.
    If DIRECTORY is specified, create it and initialize the project there.
    Otherwise use the current directory.
    """
    from pathlib import Path

    project_dir = Path(directory) if directory else Path.cwd()

    try:
        project_dir.mkdir(exist_ok=True, parents=True)
    except OSError as e:
        # Keep the original error attached as the cause.
        raise click.ClickException(f"Failed to create directory: {e}") from e

    # Both files are written as UTF-8 explicitly. Without an encoding,
    # write_text uses the locale's encoding, and the emoji in the app template
    # cannot be encoded on platforms where that is not UTF-8.

    # Create requirements.txt
    (project_dir / "requirements.txt").write_text("streamlit\n", encoding="utf-8")

    # Create streamlit_app.py
    (project_dir / "streamlit_app.py").write_text(
        """import streamlit as st
st.title("🎈 My new app")
st.write(
"Let's start building! For help and inspiration, head over to [docs.streamlit.io](https://docs.streamlit.io/)."
)
""",
        encoding="utf-8",
    )

    rel_path_str = str(directory) if directory else "."

    click.secho("✨ Created new Streamlit app in ", nl=False)
    click.secho(f"{rel_path_str}", fg="blue")
    click.echo("🚀 Run it with: ", nl=False)
    click.secho(f"streamlit run {rel_path_str}/streamlit_app.py", fg="blue")

    # Offer to start the freshly created app right away.
    if click.confirm("❓ Run the app now?", default=True):
        app_path = project_dir / "streamlit_app.py"
        click.echo("\nStarting Streamlit...")
        _main_run(str(app_path))
# Run the click entry point when this module is executed as a script.
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,26 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.web.server.component_request_handler import ComponentRequestHandler
from streamlit.web.server.routes import allow_cross_origin_requests
from streamlit.web.server.server import Server, server_address_is_unix_socket
from streamlit.web.server.stats_request_handler import StatsRequestHandler
__all__ = [
"ComponentRequestHandler",
"allow_cross_origin_requests",
"Server",
"server_address_is_unix_socket",
"StatsRequestHandler",
]

View File

@@ -0,0 +1,93 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from pathlib import Path
from typing import Final
import tornado.web
from streamlit.logger import get_logger
_LOGGER: Final = get_logger(__name__)
# We agreed on these limitations for the initial release of static file sharing,
# based on security concerns from the SiS and Community Cloud teams
# The maximum possible size of single serving static file.
MAX_APP_STATIC_FILE_SIZE = 200 * 1024 * 1024 # 200 MB
# The list of file extensions that we serve with the corresponding Content-Type header.
# All files with other extensions will be served with Content-Type: text/plain
SAFE_APP_STATIC_FILE_EXTENSIONS = (
# Common image types:
".jpg",
".jpeg",
".png",
".gif",
".webp",
# Common font types:
".otf",
".ttf",
".woff",
".woff2",
# Other types:
".pdf",
".xml",
".json",
)
class AppStaticFileHandler(tornado.web.StaticFileHandler):
    """Tornado static file handler with extra restrictions for app files.

    On top of the base handler's validation it refuses directories, paths
    outside of the static root and files larger than
    MAX_APP_STATIC_FILE_SIZE. Files whose extension is not listed in
    SAFE_APP_STATIC_FILE_EXTENSIONS are served as plain text.
    """

    def initialize(self, path: str, default_filename: str | None = None) -> None:
        super().initialize(path, default_filename)

    def validate_absolute_path(self, root: str, absolute_path: str) -> str | None:
        """Validate the requested path and return the base handler's result.

        Raises a 404 HTTPError for directories, for paths outside of ``root``
        and for existing files that exceed MAX_APP_STATIC_FILE_SIZE.
        """
        resolved_path = os.path.abspath(absolute_path)

        # The base class validation runs first; its result is returned once
        # the additional checks below have passed.
        validated_path = super().validate_absolute_path(root, absolute_path)

        # we don't want to serve directories, and serve only files
        if os.path.isdir(resolved_path):
            raise tornado.web.HTTPError(404)

        # Don't allow misbehaving clients to break out of the static files directory
        if os.path.commonpath([resolved_path, root]) != root:
            _LOGGER.warning(
                "Serving files outside of the static directory is not supported"
            )
            raise tornado.web.HTTPError(404)

        is_too_large = (
            os.path.exists(resolved_path)
            and os.path.getsize(resolved_path) > MAX_APP_STATIC_FILE_SIZE
        )
        if is_too_large:
            raise tornado.web.HTTPError(
                404,
                "File is too large, its size should not exceed "
                f"{MAX_APP_STATIC_FILE_SIZE} bytes",
                reason="File is too large",
            )

        return validated_path

    def set_default_headers(self):
        # CORS protection is disabled because we need access to this endpoint
        # from the inner iframe.
        self.set_header("Access-Control-Allow-Origin", "*")

    def set_extra_headers(self, path: str) -> None:
        """Force plain text for unlisted extensions and disable MIME sniffing."""
        extension = Path(path).suffix
        if extension not in SAFE_APP_STATIC_FILE_EXTENSIONS:
            self.set_header("Content-Type", "text/plain")
        self.set_header("X-Content-Type-Options", "nosniff")

View File

@@ -0,0 +1,60 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from authlib.integrations.base_client import ( # type: ignore[import-untyped]
FrameworkIntegration,
)
from streamlit.runtime.secrets import AttrDict
if TYPE_CHECKING:
from collections.abc import Sequence
from streamlit.web.server.oidc_mixin import TornadoOAuth
class TornadoIntegration(FrameworkIntegration):  # type: ignore[misc]
    """Authlib framework integration used with Tornado."""

    def update_token(self, token, refresh_token=None, access_token=None):
        """We do not support access token refresh, since we obtain and operate only on
        identity tokens. We override this method explicitly to implement all abstract
        methods of base class.
        """

    @staticmethod
    def load_config(
        oauth: TornadoOAuth, name: str, params: Sequence[str]
    ) -> dict[str, Any]:
        """Collect the provider parameters for Authlib from secrets.toml.

        For each key in ``params`` the value is looked up in the ``name``
        section of ``oauth.config``. Keys without a value are left out, and an
        empty dict is returned when ``oauth.config`` is empty.
        """
        # oauth.config here is an auth section from secrets.toml
        if not oauth.config:
            return {}

        provider_config: dict[str, Any] = {}
        for param_key in params:
            param_value = oauth.config.get(name, {}).get(param_key, None)

            # Nested AttrDict values (e.g. client_kwargs) are turned into plain
            # dicts so Authlib can work with them under the hood.
            if isinstance(param_value, AttrDict):
                # We want to modify client_kwargs further after loading server metadata
                param_value = param_value.to_dict()

            if param_value is None:
                continue
            provider_config[param_key] = param_value

        return provider_config

View File

@@ -0,0 +1,247 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import hmac
import json
from typing import TYPE_CHECKING, Any, Final
from urllib.parse import urlparse
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.escape import utf8
from tornado.websocket import WebSocketHandler
from streamlit import config
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.runtime import Runtime, SessionClient, SessionClientDisconnectedError
from streamlit.runtime.runtime_util import serialize_forward_msg
from streamlit.web.server.server_util import (
AUTH_COOKIE_NAME,
is_url_from_allowed_origins,
is_xsrf_enabled,
)
if TYPE_CHECKING:
from collections.abc import Awaitable
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
_LOGGER: Final = get_logger(__name__)
class BrowserWebSocketHandler(WebSocketHandler, SessionClient):
    """Handles a WebSocket connection from the browser."""

    def initialize(self, runtime: Runtime) -> None:
        """Store the runtime and prime the XSRF cookie when XSRF is enabled."""
        self._runtime = runtime
        self._session_id: str | None = None
        # The XSRF cookie is normally set when xsrf_form_html is used, but in a
        # pure-Javascript application that does not use any regular forms we just
        # need to read the self.xsrf_token manually to set the cookie as a side
        # effect. See https://www.tornadoweb.org/en/stable/guide/security.html#cross-site-request-forgery-protection
        # for more details.
        if is_xsrf_enabled():
            _ = self.xsrf_token

    def get_signed_cookie(
        self,
        name: str,
        value: str | None = None,
        max_age_days: float = 31,
        min_version: int | None = None,
    ) -> bytes | None:
        """Get a signed cookie from the request. Added for compatibility with
        Tornado < 6.3.0.

        See release notes: https://www.tornadoweb.org/en/stable/releases/v6.3.0.html#deprecation-notices
        """
        try:
            return super().get_signed_cookie(name, value, max_age_days, min_version)
        except AttributeError:
            # Fall back to the older method name when the parent class does
            # not provide `get_signed_cookie`.
            return super().get_secure_cookie(name, value, max_age_days, min_version)

    def check_origin(self, origin: str) -> bool:
        """Set up CORS.

        The origin is accepted when either Tornado's default check or
        `is_url_from_allowed_origins` accepts it.
        """
        return super().check_origin(origin) or is_url_from_allowed_origins(origin)

    def _validate_xsrf_token(self, supplied_token: str) -> bool:
        """Inspired by tornado.web.RequestHandler.check_xsrf_cookie method,
        to check the XSRF token passed in Websocket connection header.

        Returns True only when both the supplied and the expected token are
        non-empty and equal.
        """
        _, token, _ = self._decode_xsrf_token(supplied_token)
        _, expected_token, _ = self._get_raw_xsrf_token()

        decoded_token = utf8(token)
        decoded_expected_token = utf8(expected_token)

        if not decoded_token or not decoded_expected_token:
            return False
        # Constant-time comparison of the two tokens.
        return hmac.compare_digest(decoded_token, decoded_expected_token)

    def _parse_user_cookie(self, raw_cookie_value: bytes) -> dict[str, Any]:
        """Process the user cookie and extract the user info after
        validating the origin. Origin is validated for security reasons.

        Returns an empty dict when the request carries no ``Origin`` header or
        when the origin stored in the cookie does not match the header.
        """
        cookie_value = json.loads(raw_cookie_value)
        user_info: dict[str, Any] = {}

        # Without an Origin header there is nothing to validate the cookie
        # against, so treat the user as not logged in instead of raising a
        # KeyError (which `open` would swallow together with the session id).
        origin_header = self.request.headers.get("Origin")
        if origin_header is None:
            return user_info

        parsed_origin_from_header = urlparse(origin_header)
        expected_origin_value = (
            parsed_origin_from_header.scheme + "://" + parsed_origin_from_header.netloc
        )

        if cookie_value.get("origin", None) == expected_origin_value:
            # `pop` with a default keeps this consistent with the "missing
            # means False" intent and cannot raise if a key is absent.
            user_info["is_logged_in"] = cookie_value.pop("is_logged_in", False)
            cookie_value.pop("origin", None)
            user_info.update(cookie_value)

        return user_info

    def write_forward_msg(self, msg: ForwardMsg) -> None:
        """Send a ForwardMsg to the browser.

        Raises
        ------
        SessionClientDisconnectedError
            If the websocket is already closed.
        """
        try:
            self.write_message(serialize_forward_msg(msg), binary=True)
        except tornado.websocket.WebSocketClosedError as e:
            raise SessionClientDisconnectedError from e

    def select_subprotocol(self, subprotocols: list[str]) -> str | None:
        """Return the first subprotocol in the given list.

        This method is used by Tornado to select a protocol when the
        Sec-WebSocket-Protocol header is set in an HTTP Upgrade request.

        NOTE: We repurpose the Sec-WebSocket-Protocol header here in a slightly
        unfortunate (but necessary) way. The browser WebSocket API doesn't allow us to
        set arbitrary HTTP headers, and this header is the only one where we have the
        ability to set it to arbitrary values, so we use it to pass tokens (in this
        case, the previous session ID to allow us to reconnect to it) from client to
        server as the *third* value in the list.

        The reason why the session ID is set as the third value is that:
          - when Sec-WebSocket-Protocol is set, many clients expect the server to
            respond with a selected subprotocol to use. We don't want that reply to be
            the session token, so we by convention have the client always set the first
            protocol to "streamlit" and select that.
          - the second protocol in the list is reserved for a token; `open` reads it
            as the XSRF token when XSRF protection is enabled and an auth cookie is
            present.
        """
        if subprotocols:
            return subprotocols[0]
        return None

    def open(self, *args, **kwargs) -> Awaitable[None] | None:
        """Connect this websocket to a runtime session.

        User info is taken from the signed auth cookie only when XSRF
        protection is enabled and the XSRF token sent as the second
        subprotocol is valid. The third subprotocol, if present, is used as
        the id of an existing session to reconnect to.
        """
        user_info: dict[str, str | bool | None] = {}
        existing_session_id = None

        try:
            ws_protocols = [
                p.strip()
                for p in self.request.headers["Sec-Websocket-Protocol"].split(",")
            ]

            raw_cookie_value = self.get_signed_cookie(AUTH_COOKIE_NAME)
            # The XSRF token travels as the second subprotocol. If the client
            # sent fewer than two entries there is no token to validate, so
            # the user stays anonymous instead of the handler crashing with
            # an uncaught IndexError.
            if is_xsrf_enabled() and raw_cookie_value and len(ws_protocols) >= 2:
                csrf_protocol_value = ws_protocols[1]

                if self._validate_xsrf_token(csrf_protocol_value):
                    user_info.update(self._parse_user_cookie(raw_cookie_value))

            if len(ws_protocols) >= 3:
                # See the NOTE in the docstring of the `select_subprotocol` method above
                # for a detailed explanation of why this is done.
                existing_session_id = ws_protocols[2]
        except KeyError:
            # Just let existing_session_id=None if we run into any error while trying to
            # extract it from the Sec-Websocket-Protocol header.
            pass

        self._session_id = self._runtime.connect_session(
            client=self,
            user_info=user_info,
            existing_session_id=existing_session_id,
        )
        return None

    def on_close(self) -> None:
        """Disconnect the runtime session, if one was established."""
        if not self._session_id:
            return
        self._runtime.disconnect_session(self._session_id)
        self._session_id = None

    def get_compression_options(self) -> dict[Any, Any] | None:
        """Enable WebSocket compression.

        Returning an empty dict enables websocket compression. Returning
        None disables it.

        (See the docstring in the parent class.)
        """
        if config.get_option("server.enableWebsocketCompression"):
            return {}
        return None

    def on_message(self, payload: str | bytes) -> None:
        """Deserialize a BackMsg from the browser and hand it to the runtime.

        Messages received before a session exists are ignored. Payloads that
        cannot be deserialized are reported to the runtime and dropped.
        """
        if not self._session_id:
            return

        try:
            if isinstance(payload, str):
                # Sanity check. (The frontend should only be sending us bytes;
                # Protobuf.ParseFromString does not accept str input.)
                raise RuntimeError(
                    "WebSocket received an unexpected `str` message. "
                    "(We expect `bytes` only.)"
                )

            msg = BackMsg()
            msg.ParseFromString(payload)
            _LOGGER.debug("Received the following back message:\n%s", msg)

        except Exception as ex:
            _LOGGER.exception("Error deserializing back message")
            self._runtime.handle_backmsg_deserialization_exception(self._session_id, ex)
            return

        # "debug_disconnect_websocket" and "debug_shutdown_runtime" are special
        # developmentMode-only messages used in e2e tests to test reconnect handling and
        # disabling widgets.
        msg_type = msg.WhichOneof("type")
        if msg_type == "debug_disconnect_websocket":
            if config.get_option("global.developmentMode") or config.get_option(
                "global.e2eTest"
            ):
                self.close()
            else:
                _LOGGER.warning(
                    "Client tried to disconnect websocket when not in development mode or e2e testing."
                )
        elif msg_type == "debug_shutdown_runtime":
            if config.get_option("global.developmentMode") or config.get_option(
                "global.e2eTest"
            ):
                self._runtime.stop()
            else:
                _LOGGER.warning(
                    "Client tried to shut down runtime when not in development mode or e2e testing."
                )
        else:
            # AppSession handles all other BackMsg types.
            self._runtime.handle_backmsg(self._session_id, msg)

View File

@@ -0,0 +1,116 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import mimetypes
import os
from typing import TYPE_CHECKING, Final
import tornado.web
import streamlit.web.server.routes
from streamlit.logger import get_logger
if TYPE_CHECKING:
from streamlit.components.types.base_component_registry import BaseComponentRegistry
_LOGGER: Final = get_logger(__name__)
class ComponentRequestHandler(tornado.web.RequestHandler):
    """Serves files that belong to registered custom components."""

    def initialize(self, registry: BaseComponentRegistry):
        """Store the component registry used to resolve component roots."""
        self._registry = registry

    def get(self, path: str) -> None:
        """Serve the file addressed by ``<component_name>/<relative path>``.

        Responds with 404 when the component is unknown or the file cannot be
        read, and with 403 when the resolved path lies outside the component
        root.
        """
        parts = path.split("/")
        component_name = parts[0]
        component_root = self._registry.get_component_path(component_name)
        if component_root is None:
            self.write("not found")
            self.set_status(404)
            return

        # follow symlinks to get an accurate normalized path
        component_root = os.path.realpath(component_root)
        filename = "/".join(parts[1:])
        abspath = os.path.normpath(os.path.join(component_root, filename))

        # Do NOT expose anything outside of the component root.
        # `os.path.commonpath` raises ValueError when the two paths cannot be
        # compared (e.g. they are on different drives); such a path is by
        # definition outside the root, so answer 403 rather than failing
        # with an unhandled exception.
        try:
            inside_root = (
                os.path.commonpath([component_root, abspath]) == component_root
            )
        except ValueError:
            inside_root = False

        if not inside_root:
            self.write("forbidden")
            self.set_status(403)
            return

        try:
            with open(abspath, "rb") as file:
                contents = file.read()
        except OSError as e:
            _LOGGER.error(
                "ComponentRequestHandler: GET %s read error", abspath, exc_info=e
            )
            self.write("read error")
            self.set_status(404)
            return

        self.write(contents)
        self.set_header("Content-Type", self.get_content_type(abspath))

        self.set_extra_headers(path)

    def set_extra_headers(self, path: str) -> None:
        """Disable cache for HTML files.

        Other assets like JS and CSS are suffixed with their hash, so they can
        be cached indefinitely.
        """
        is_index_url = len(path) == 0

        if is_index_url or path.endswith(".html"):
            self.set_header("Cache-Control", "no-cache")
        else:
            self.set_header("Cache-Control", "public")

    def set_default_headers(self) -> None:
        """Add a wildcard CORS header when cross-origin requests are allowed."""
        if streamlit.web.server.routes.allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self) -> None:
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()

    @staticmethod
    def get_content_type(abspath: str) -> str:
        """Returns the ``Content-Type`` header to be used for this request.
        From tornado.web.StaticFileHandler.
        """
        mime_type, encoding = mimetypes.guess_type(abspath)
        # per RFC 6713, use the appropriate type for a gzip compressed file
        if encoding == "gzip":
            return "application/gzip"
        # As of 2015-07-21 there is no bzip2 encoding defined at
        # http://www.iana.org/assignments/media-types/media-types.xhtml
        # So for that (and any other encoding), use octet-stream.
        if encoding is not None:
            return "application/octet-stream"
        if mime_type is not None:
            return mime_type
        # if mime_type not detected, use application/octet-stream
        return "application/octet-stream"

    @staticmethod
    def get_url(file_id: str) -> str:
        """Return the URL for a component file with the given ID."""
        return f"components/{file_id}"

View File

@@ -0,0 +1,141 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from urllib.parse import quote
import tornado.web
from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorageError
from streamlit.runtime.memory_media_file_storage import (
MemoryMediaFileStorage,
get_extension_for_mimetype,
)
from streamlit.web.server import allow_cross_origin_requests
_LOGGER = get_logger(__name__)
class MediaFileHandler(tornado.web.StaticFileHandler):
    """Static-file handler that serves content from a MemoryMediaFileStorage."""

    # Shared by all handler instances; assigned through `initialize_storage`.
    _storage: MemoryMediaFileStorage

    @classmethod
    def initialize_storage(cls, storage: MemoryMediaFileStorage) -> None:
        """Set the MemoryMediaFileStorage object used by instances of this
        handler. Must be called on server startup.
        """
        # Stored on the class (not the instance) because `get_content()` is a
        # classmethod and has to reach the storage as well.
        cls._storage = storage

    def set_default_headers(self) -> None:
        """Add a wildcard CORS header when cross-origin requests are allowed."""
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def set_extra_headers(self, path: str) -> None:
        """Add Content-Disposition header for downloadable files.

        The header value is "attachment", indicating that the file should be
        saved locally instead of being displayed inline in the browser, and it
        carries the filename to use for the download.

        Used for serving downloadable files, like files stored via the
        `st.download_button` widget.
        """
        media_file = self._storage.get_file(path)
        if not (media_file and media_file.kind == MediaFileKind.DOWNLOADABLE):
            return

        download_name = media_file.filename
        if not download_name:
            download_name = f"streamlit_download{get_extension_for_mimetype(media_file.mimetype)}"

        try:
            # Headers are latin1-encoded by default, so probe whether the
            # name survives that encoding.
            download_name.encode("latin1")
        except UnicodeEncodeError:
            # RFC5987 syntax.
            # See: https://datatracker.ietf.org/doc/html/rfc5987
            file_expr = f"filename*=utf-8''{quote(download_name)}"
        else:
            file_expr = f'filename="{download_name}"'

        self.set_header("Content-Disposition", f"attachment; {file_expr}")

    # Overriding StaticFileHandler to use the MediaFileManager
    #
    # From the Tornado docs:
    # To replace all interaction with the filesystem (e.g. to serve
    # static content from a database), override `get_content`,
    # `get_content_size`, `get_modified_time`, `get_absolute_path`, and
    # `validate_absolute_path`.
    def validate_absolute_path(self, root: str, absolute_path: str) -> str:
        """Return `absolute_path` if the storage knows it, else raise a 404."""
        try:
            self._storage.get_file(absolute_path)
        except MediaFileStorageError:
            _LOGGER.error("MediaFileHandler: Missing file %s", absolute_path)
            raise tornado.web.HTTPError(404, "not found")

        return absolute_path

    def get_content_size(self) -> int:
        """Return the stored file's content size, or 0 when no path is set."""
        stored_path = self.absolute_path
        if stored_path is None:
            return 0
        return self._storage.get_file(stored_path).content_size

    def get_modified_time(self) -> None:
        """Always return None: modification times are not tracked."""
        # We do not track last modified time, but this can be improved to
        # allow caching among files in the MediaFileManager
        return None

    @classmethod
    def get_absolute_path(cls, root: str, path: str) -> str:
        """Return `path` unchanged."""
        # All files are stored in memory, so the absolute path is just the
        # path itself. In the MediaFileHandler, it's just the filename
        return path

    @classmethod
    def get_content(
        cls, abspath: str, start: int | None = None, end: int | None = None
    ):
        """Return the stored content for `abspath`, optionally sliced.

        Returns None when the file cannot be fetched from the storage.
        """
        _LOGGER.debug("MediaFileHandler: GET %s", abspath)

        try:
            # abspath is the hash as used `get_absolute_path`
            media_file = cls._storage.get_file(abspath)
        except Exception:
            _LOGGER.error("MediaFileHandler: Missing file %s", abspath)
            return None

        _LOGGER.debug(
            "MediaFileHandler: Sending %s file %s", media_file.mimetype, abspath
        )

        # No range requested: hand back the whole payload.
        if start is None and end is None:
            return media_file.content

        lower = 0 if start is None else start
        upper = len(media_file.content) if end is None else end

        # content is bytes, so a plain slice yields the requested range.
        return media_file.content[lower:upper]

View File

@@ -0,0 +1,176 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
from typing import Any
from urllib.parse import urlparse
import tornado.web
from streamlit.auth_util import (
AuthCache,
decode_provider_token,
generate_default_provider_section,
get_secrets_auth_section,
)
from streamlit.errors import StreamlitAuthError
from streamlit.url_util import make_url_path
from streamlit.web.server.oidc_mixin import TornadoOAuth, TornadoOAuth2App
from streamlit.web.server.server_util import AUTH_COOKIE_NAME
# Module-level cache handed to every TornadoOAuth instance built by
# `create_oauth_client`; `AuthCallbackHandler._get_provider_by_state` reads its
# keys to map a returned state code back to the provider that issued it.
auth_cache = AuthCache()
def create_oauth_client(provider: str) -> tuple[TornadoOAuth2App, str]:
    """Create an OAuth client for the given provider based on secrets.toml configuration.

    Returns the client together with the redirect URI taken from the auth
    section ("/" when there is no auth section).
    """
    auth_section = get_secrets_auth_section()
    if not auth_section:
        redirect_uri = "/"
        config = {}
    else:
        redirect_uri = auth_section.get("redirect_uri", None)
        config = auth_section.to_dict()

    provider_section = config.setdefault(provider, {})

    # An empty "default" provider gets a generated section.
    if provider == "default" and not provider_section:
        provider_section = generate_default_provider_section(auth_section)
        config["default"] = provider_section

    # Fill in scope/prompt only where the configuration left them out.
    client_kwargs = provider_section.setdefault("client_kwargs", {})
    for option, fallback in (
        ("scope", "openid email profile"),
        ("prompt", "select_account"),
    ):
        if option not in client_kwargs:
            client_kwargs[option] = fallback

    oauth = TornadoOAuth(config, cache=auth_cache)
    oauth.register(provider)
    return oauth.create_client(provider), redirect_uri
class AuthHandlerMixin(tornado.web.RequestHandler):
    """Mixin for handling auth cookies. Added for compatibility with Tornado < 6.3.0."""

    def initialize(self, base_url: str) -> None:
        """Store the base URL that `redirect_to_base` redirects to."""
        self.base_url = base_url

    def redirect_to_base(self) -> None:
        """Redirect the client to the app's base URL."""
        self.redirect(make_url_path(self.base_url, "/"))

    def set_auth_cookie(self, user_info: dict[str, Any]) -> None:
        """Serialize `user_info` as JSON and store it in the signed auth cookie."""
        serialized_cookie_value = json.dumps(user_info)

        # The keyword is spelled lowercase (`httponly`) in both branches:
        # Tornado 6.3 deprecates mixed-case cookie keyword arguments such as
        # `httpOnly`, and the lowercase form is accepted by older versions too.
        try:
            # We don't specify Tornado secure flag here because it leads to missing cookie on Safari.
            # The OIDC flow should work only on secure context anyway (localhost or HTTPS),
            # so specifying the secure flag here will not add anything in terms of security.
            self.set_signed_cookie(
                AUTH_COOKIE_NAME,
                serialized_cookie_value,
                httponly=True,
            )
        except AttributeError:
            # Older method name, used when `set_signed_cookie` is unavailable.
            self.set_secure_cookie(
                AUTH_COOKIE_NAME,
                serialized_cookie_value,
                httponly=True,
            )

    def clear_auth_cookie(self) -> None:
        """Delete the auth cookie."""
        self.clear_cookie(AUTH_COOKIE_NAME)
class AuthLoginHandler(AuthHandlerMixin, tornado.web.RequestHandler):
    """Starts the login flow for the provider named in the request."""

    async def get(self):
        """Redirect to the OAuth provider login page."""
        provider = self._parse_provider_token()
        if provider is None:
            self.redirect_to_base()
            return

        client, redirect_uri = create_oauth_client(provider)
        try:
            client.authorize_redirect(self, redirect_uri)
        except Exception as exc:
            self.send_error(400, reason=str(exc))

    def _parse_provider_token(self) -> str | None:
        """Return the provider named by the `provider` query argument.

        Returns None when the argument is absent or cannot be decoded.
        """
        encoded = self.get_argument("provider", None)
        if encoded is None:
            return None

        try:
            decoded = decode_provider_token(encoded)
        except StreamlitAuthError:
            return None

        return decoded["provider"]
class AuthLogoutHandler(AuthHandlerMixin, tornado.web.RequestHandler):
    """Logs the user out by clearing the auth cookie."""

    def get(self):
        """Clear the auth cookie, then redirect to the app's base URL."""
        self.clear_auth_cookie()
        self.redirect_to_base()
class AuthCallbackHandler(AuthHandlerMixin, tornado.web.RequestHandler):
    """Handles the redirect back from the OAuth provider."""

    async def get(self):
        """Finish the login flow and redirect to the app's base URL.

        The auth cookie is set only when the token exchange yields user info.
        If no origin is configured, the provider reported an error, or the
        state does not map to a known provider, the request is redirected
        without attempting the token exchange.
        """
        provider = self._get_provider_by_state()
        origin = self._get_origin_from_secrets()
        error = self.get_argument("error", None)

        if origin is None or error or provider is None:
            self.redirect_to_base()
            return

        client, _ = create_oauth_client(provider)
        token = client.authorize_access_token(self)
        user = token.get("userinfo")

        # Build the cookie payload only once we know `user` is present:
        # `dict(None, ...)` raises TypeError, which previously turned a token
        # without "userinfo" into a server error instead of a plain redirect.
        if user:
            cookie_value = dict(user, origin=origin, is_logged_in=True)
            self.set_auth_cookie(cookie_value)

        self.redirect_to_base()

    def _get_provider_by_state(self) -> str | None:
        """Return the provider recorded in the auth cache for the `state`
        query argument, or None when the state is unknown.
        """
        state_code_from_url = self.get_argument("state")
        current_cache_keys = list(auth_cache.get_dict().keys())
        state_provider_mapping = {}
        for key in current_cache_keys:
            # Each key is expected to split into exactly four "_"-separated
            # parts, the last two being the provider name and the state code.
            _, _, recorded_provider, code = key.split("_")
            state_provider_mapping[code] = recorded_provider

        provider: str | None = state_provider_mapping.get(state_code_from_url, None)
        return provider

    def _get_origin_from_secrets(self) -> str | None:
        """Return ``scheme://netloc`` of the configured `redirect_uri`.

        Returns None when there is no auth section or no `redirect_uri`.
        """
        redirect_uri = None
        auth_section = get_secrets_auth_section()
        if auth_section:
            redirect_uri = auth_section.get("redirect_uri", None)

        if not redirect_uri:
            return None

        redirect_uri_parsed = urlparse(redirect_uri)
        origin_from_redirect_uri: str = (
            redirect_uri_parsed.scheme + "://" + redirect_uri_parsed.netloc
        )
        return origin_from_redirect_uri

View File

@@ -0,0 +1,108 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tornado.web
from authlib.integrations.base_client import ( # type: ignore[import-untyped]
BaseApp,
BaseOAuth,
OAuth2Mixin,
OAuthError,
OpenIDMixin,
)
from authlib.integrations.requests_client import ( # type: ignore[import-untyped]
OAuth2Session,
)
from streamlit.web.server.authlib_tornado_integration import TornadoIntegration
class TornadoOAuth2App(OAuth2Mixin, OpenIDMixin, BaseApp):  # type: ignore[misc]
    """OAuth2/OpenID client app driven by Tornado request handlers."""

    client_cls = OAuth2Session

    def load_server_metadata(self):
        """We enforce S256 code challenge method if it is supported by the server."""
        metadata = super().load_server_metadata()
        supported_methods = metadata.get("code_challenge_methods_supported", [])
        if "S256" in supported_methods:
            self.client_kwargs["code_challenge_method"] = "S256"
        return metadata

    def authorize_redirect(
        self, request_handler: tornado.web.RequestHandler, redirect_uri=None, **kwargs
    ):
        """Create a HTTP Redirect for Authorization Endpoint.

        :param request_handler: HTTP request instance from Tornado.
        :param redirect_uri: Callback or redirect URI for authorization.
        :param kwargs: Extra parameters to include.
        :return: A HTTP redirect response.
        """
        authorization = self.create_authorization_url(redirect_uri, **kwargs)
        # Persist the state data before sending the browser away.
        self._save_authorize_data(redirect_uri=redirect_uri, **authorization)
        request_handler.redirect(authorization["url"], status=302)

    def authorize_access_token(
        self, request_handler: tornado.web.RequestHandler, **kwargs
    ):
        """
        :param request_handler: HTTP request instance from Tornado.
        :return: A token dict.
        """
        provider_error = request_handler.get_argument("error", None)
        if provider_error:
            raise OAuthError(
                error=provider_error,
                description=request_handler.get_argument("error_description", None),
            )

        code = request_handler.get_argument("code")
        state = request_handler.get_argument("state")
        params = {"code": code, "state": state}

        assert self.framework.cache is not None
        # No session object is used here; None is passed where the framework
        # expects a session.
        session = None

        claims_options = kwargs.pop("claims_options", None)
        state_data = self.framework.get_state_data(session, state)
        self.framework.clear_state_data(session, state)
        params = self._format_state_params(state_data, params)

        token = self.fetch_access_token(**params, **kwargs)

        if "id_token" not in token or "nonce" not in state_data:
            return token

        userinfo = self.parse_id_token(
            token, nonce=state_data["nonce"], claims_options=claims_options
        )
        return {**token, "userinfo": userinfo}

    def _save_authorize_data(self, **kwargs):
        """Authlib underlying uses the concept of "session" to store state data.
        In Tornado, we don't have a session, so we use the framework's cache option.
        """
        state = kwargs.pop("state", None)
        if not state:
            raise RuntimeError("Missing state value")

        assert self.framework.cache is not None
        self.framework.set_state_data(None, state, kwargs)
class TornadoOAuth(BaseOAuth):  # type: ignore[misc]
    """OAuth registry that creates `TornadoOAuth2App` clients and uses
    `TornadoIntegration` as its framework integration.
    """

    oauth2_client_cls = TornadoOAuth2App
    framework_integration_cls = TornadoIntegration

    def __init__(self, config=None, cache=None, fetch_token=None, update_token=None):
        """Forward cache and token hooks to the base class and keep `config`."""
        super().__init__(
            cache=cache, fetch_token=fetch_token, update_token=update_token
        )
        # Stored on the instance; not passed to the base-class constructor.
        self.config = config

View File

@@ -0,0 +1,296 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Final
import tornado.web
from streamlit import config, file_util
from streamlit.logger import get_logger
from streamlit.runtime.runtime_util import serialize_forward_msg
from streamlit.web.server.server_util import (
emit_endpoint_deprecation_notice,
is_xsrf_enabled,
)
if TYPE_CHECKING:
from collections.abc import Sequence
_LOGGER: Final = get_logger(__name__)
def allow_cross_origin_requests() -> bool:
    """True if cross-origin requests are allowed.

    We only allow cross-origin requests when CORS protection has been disabled
    with server.enableCORS=False or if using the Node server. When using the
    Node server, we have a dev and prod port, which count as two origins.
    """
    # CORS protection switched off: every origin is allowed.
    if not config.get_option("server.enableCORS"):
        return True
    # Otherwise the answer is whatever the development-mode option holds.
    return config.get_option("global.developmentMode")
class StaticFileHandler(tornado.web.StaticFileHandler):
    """Static file handler that serves the default file for unknown paths."""

    def initialize(
        self,
        path: str,
        default_filename: str | None = None,
        reserved_paths: Sequence[str] = (),
    ):
        """Remember the reserved paths and run the parent initialisation."""
        self._reserved_paths = reserved_paths

        super().initialize(path, default_filename)

    def set_extra_headers(self, path: str) -> None:
        """Disable cache for HTML files.

        Other assets like JS and CSS are suffixed with their hash, so they can
        be cached indefinitely.
        """
        # An empty path is the index URL.
        if not path or path.endswith(".html"):
            cache_policy = "no-cache"
        else:
            cache_policy = "public"
        self.set_header("Cache-Control", cache_policy)

    def validate_absolute_path(self, root: str, absolute_path: str) -> str | None:
        """Validate the path, falling back to the default file on a 404.

        A 404 is re-raised when the requested URL path ends with one of the
        reserved paths; any other HTTP error is always re-raised.
        """
        try:
            return super().validate_absolute_path(root, absolute_path)
        except tornado.web.HTTPError as e:
            if e.status_code != 404:
                raise e

            # The file was not found. Unless the path is reserved, serve the
            # default file and let the frontend handle the issue.
            requested = self.path
            # self.path is an OS specific file path; turn it into a URL path
            # before comparing it with the reserved paths.
            if os.path.sep != "/":
                requested = requested.replace(os.path.sep, "/")

            for reserved in self._reserved_paths:
                if requested.endswith(reserved):
                    raise e

            self.path = self.parse_url_path(self.default_filename or "index.html")
            fallback_path = self.get_absolute_path(self.root, self.path)
            return super().validate_absolute_path(root, fallback_path)

    def write_error(self, status_code: int, **kwargs) -> None:
        """Render the static index.html for 404s; defer to the parent otherwise."""
        if status_code != 404:
            super().write_error(status_code, **kwargs)
            return

        self.render(os.path.join(file_util.get_static_dir(), "index.html"))
class AddSlashHandler(tornado.web.RequestHandler):
    """Handler whose GET behaviour comes entirely from `tornado.web.addslash`."""

    @tornado.web.addslash
    def get(self):
        # Intentionally empty: the decorator does all the work (per Tornado's
        # documentation it redirects to the URL with a trailing slash added).
        pass
class RemoveSlashHandler(tornado.web.RequestHandler):
    """Handler whose GET behaviour comes entirely from `tornado.web.removeslash`."""

    @tornado.web.removeslash
    def get(self):
        # Intentionally empty: the decorator does all the work (per Tornado's
        # documentation it redirects to the URL with the trailing slash removed).
        pass
class _SpecialRequestHandler(tornado.web.RequestHandler):
    """Superclass for "special" endpoints, like /healthz."""

    def set_default_headers(self):
        """Disable caching and, when allowed, add a wildcard CORS header."""
        self.set_header("Cache-Control", "no-cache")
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self):
        """/OPTIONS handler for preflight CORS checks.

        When a browser is making a CORS request, it may sometimes first
        send an OPTIONS request, to check whether the server understands the
        CORS protocol. This is optional, and doesn't happen for every request
        or in every browser. If an OPTIONS request does get sent, and is not
        then handled by the server, the browser will fail the underlying
        request.

        The proper way to handle this is to send a 204 response ("no content")
        with the CORS headers attached. (These headers are automatically added
        to every outgoing response, including OPTIONS responses,
        via set_default_headers().)

        See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
        """
        self.set_status(204)
        self.finish()
class HealthHandler(_SpecialRequestHandler):
    """Reports server health for GET and HEAD requests."""

    def initialize(self, callback):
        """Initialize the handler.

        Parameters
        ----------
        callback : callable
            A function that returns True if the server is healthy

        """
        self._callback = callback

    async def get(self):
        """Answer a GET health check."""
        await self.handle_request()

    # Some monitoring services only support the HTTP HEAD method for requests to
    # healthcheck endpoints, so we support HEAD as well to play nicely with them.
    async def head(self):
        """Answer a HEAD health check."""
        await self.handle_request()

    async def handle_request(self):
        """Run the health callback and write the matching response.

        Responds 200 with the callback's message when healthy, 503 with the
        message otherwise. Requests whose URI lacks "_stcore/" additionally
        trigger an endpoint deprecation notice.
        """
        uri = self.request.uri
        if uri and "_stcore/" not in uri:
            if "script-health-check" in uri:
                replacement = "/_stcore/script-health-check"
            else:
                replacement = "/_stcore/health"
            emit_endpoint_deprecation_notice(self, new_path=replacement)

        healthy, message = await self._callback()

        if not healthy:
            # 503 = SERVICE_UNAVAILABLE
            self.set_status(503)
            self.write(message)
            return

        self.write(message)
        self.set_status(200)

        # Tornado will set the _streamlit_xsrf cookie automatically for the page on
        # request for the document. However, if the server is reset and
        # server.enableXsrfProtection is updated, the browser does not reload the document.
        # Manually setting the cookie on /healthz since it is pinged when the
        # browser is disconnected from the server.
        if not is_xsrf_enabled():
            return

        xsrf_cookie_kwargs = self.settings.get("xsrf_cookie_kwargs", {})
        xsrf_cookie_name = self.settings.get("xsrf_cookie_name", "_streamlit_xsrf")
        self.set_cookie(xsrf_cookie_name, self.xsrf_token, **xsrf_cookie_kwargs)
# Origins reported by default in the host-config response. HostConfigHandler
# copies this list (and may extend its copy) and returns it as "allowedOrigins".
_DEFAULT_ALLOWED_MESSAGE_ORIGINS = [
    # Community-cloud related domains.
    # We can remove these in the future if community cloud
    # provides those domains via the host-config endpoint.
    "https://devel.streamlit.test",
    "https://*.streamlit.apptest",
    "https://*.streamlitapp.test",
    "https://*.streamlitapp.com",
    "https://share.streamlit.io",
    "https://share-demo.streamlit.io",
    "https://share-head.streamlit.io",
    "https://share-staging.streamlit.io",
    "https://*.demo.streamlit.run",
    "https://*.head.streamlit.run",
    "https://*.staging.streamlit.run",
    "https://*.streamlit.run",
    "https://*.demo.streamlit.app",
    "https://*.head.streamlit.app",
    "https://*.staging.streamlit.app",
    "https://*.streamlit.app",
]
class HostConfigHandler(_SpecialRequestHandler):
    """Serves the host configuration, including the allowed message origins."""

    def initialize(self):
        """Build this handler's private list of allowed origins."""
        # Work on a copy so the module-level default list is never mutated.
        origins = _DEFAULT_ALLOWED_MESSAGE_ORIGINS.copy()

        if config.get_option("global.developmentMode"):
            # Allow messages from localhost in dev mode for testing of host <-> guest communication
            if "http://localhost" not in origins:
                origins.append("http://localhost")

        self._allowed_origins = origins

    async def get(self) -> None:
        """Write the host configuration with a 200 status."""
        host_config = {
            "allowedOrigins": self._allowed_origins,
            "useExternalAuthToken": False,
            # Default host configuration settings.
            "enableCustomParentMessages": False,
            "enforceDownloadInNewTab": False,
            "metricsUrl": "",
            "blockErrorDialogs": False,
        }
        self.write(host_config)
        self.set_status(200)
class MessageCacheHandler(tornado.web.RequestHandler):
    """Returns ForwardMsgs from our MessageCache."""

    def initialize(self, cache):
        """Initializes the handler.

        Parameters
        ----------
        cache : MessageCache

        """
        self._cache = cache

    def set_default_headers(self):
        """Add a wildcard CORS header when cross-origin requests are allowed."""
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def get(self):
        """Serve the cached message identified by the `hash` query argument.

        Responds 418 when in-memory message caching is disabled, 404 when the
        hash is missing or unknown, and 200 with the serialized message
        otherwise.
        """
        requested_hash = self.get_argument("hash", None)

        caching_enabled = config.get_option("global.storeCachedForwardMessagesInMemory")
        if not caching_enabled:
            # We use rare status code here, to distinguish between normal 404s.
            self.set_status(418)
            self.finish()
            return

        if requested_hash is None:
            # Hash is missing! This is a malformed request.
            _LOGGER.error(
                "HTTP request for cached message is missing the hash attribute."
            )
            self.set_status(404)
            raise tornado.web.Finish()

        cached_message = self._cache.get_message(requested_hash)
        if cached_message is None:
            # Message not in our cache.
            _LOGGER.error(
                "HTTP request for cached message could not be fulfilled. "
                "No such message"
            )
            self.set_status(404)
            raise tornado.web.Finish()

        _LOGGER.debug("MessageCache HIT")
        serialized = serialize_forward_msg(cached_message)
        self.set_header("Content-Type", "application/octet-stream")
        self.write(serialized)
        self.set_status(200)

    def options(self):
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()

View File

@@ -0,0 +1,479 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import errno
import logging
import mimetypes
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any, Final
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.httpserver import HTTPServer
from streamlit import cli_util, config, file_util, util
from streamlit.auth_util import is_authlib_installed
from streamlit.config_option import ConfigOption
from streamlit.logger import get_logger
from streamlit.runtime import Runtime, RuntimeConfig, RuntimeState
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
from streamlit.runtime.memory_session_storage import MemorySessionStorage
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
from streamlit.runtime.runtime_util import get_max_message_size_bytes
from streamlit.web.cache_storage_manager_config import (
create_default_cache_storage_manager,
)
from streamlit.web.server.app_static_file_handler import AppStaticFileHandler
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler
from streamlit.web.server.component_request_handler import ComponentRequestHandler
from streamlit.web.server.media_file_handler import MediaFileHandler
from streamlit.web.server.routes import (
AddSlashHandler,
HealthHandler,
HostConfigHandler,
MessageCacheHandler,
RemoveSlashHandler,
StaticFileHandler,
)
from streamlit.web.server.server_util import (
DEVELOPMENT_PORT,
get_cookie_secret,
is_xsrf_enabled,
make_url_path_regex,
)
from streamlit.web.server.stats_request_handler import StatsRequestHandler
from streamlit.web.server.upload_file_request_handler import UploadFileRequestHandler
if TYPE_CHECKING:
from collections.abc import Awaitable
from ssl import SSLContext
# Module-level logger used by the listening helpers and the Server class below.
_LOGGER: Final = get_logger(__name__)

# Extra keyword settings expanded into tornado.web.Application in
# Server._create_app.
TORNADO_SETTINGS = {
    # Gzip HTTP responses.
    "compress_response": True,
    # Ping every 1s to keep WS alive.
    # 2021.06.22: this value was previously 20s, and was causing
    # connection instability for a small number of users. This smaller
    # ping_interval fixes that instability.
    # https://github.com/streamlit/streamlit/issues/3196
    "websocket_ping_interval": 1,
    # If we don't get a ping response within 30s, the connection
    # is timed out.
    "websocket_ping_timeout": 30,
    # Name of the cookie Tornado uses for its XSRF token.
    "xsrf_cookie_name": "_streamlit_xsrf",
}

# When server.port is not available, the server looks for the next available
# port, giving up after MAX_PORT_SEARCH_RETRIES attempts.
MAX_PORT_SEARCH_RETRIES: Final = 100

# When server.address starts with this prefix, the server will bind
# to a Unix socket.
UNIX_SOCKET_PREFIX: Final = "unix://"

# Endpoint paths / path regex fragments. They are combined with the configured
# base URL path via make_url_path_regex() in Server._create_app.
MEDIA_ENDPOINT: Final = "/media"
UPLOAD_FILE_ENDPOINT: Final = "/_stcore/upload_file"
STREAM_ENDPOINT: Final = r"_stcore/stream"
# Matches either `st-metrics` or `_stcore/metrics`.
METRIC_ENDPOINT: Final = r"(?:st-metrics|_stcore/metrics)"
MESSAGE_ENDPOINT: Final = r"_stcore/message"
NEW_HEALTH_ENDPOINT: Final = "_stcore/health"
# Matches either `healthz` or NEW_HEALTH_ENDPOINT.
HEALTH_ENDPOINT: Final = rf"(?:healthz|{NEW_HEALTH_ENDPOINT})"
HOST_CONFIG_ENDPOINT: Final = r"_stcore/host-config"
# Matches either `script-health-check` or `_stcore/script-health-check`.
SCRIPT_HEALTH_CHECK_ENDPOINT: Final = (
    r"(?:script-health-check|_stcore/script-health-check)"
)
# Auth routes; only registered when is_authlib_installed() is true
# (see Server._create_app).
OAUTH2_CALLBACK_ENDPOINT: Final = "/oauth2callback"
AUTH_LOGIN_ENDPOINT: Final = "/auth/login"
AUTH_LOGOUT_ENDPOINT: Final = "/auth/logout"
class RetriesExceeded(Exception):
    """Raised when no free TCP port is found within MAX_PORT_SEARCH_RETRIES attempts."""
def server_port_is_manually_set() -> bool:
    """Report whether the ``server.port`` option is flagged as manually set."""
    port_option = "server.port"
    return config.is_manually_set(port_option)
def server_address_is_unix_socket() -> bool:
    """Return True when ``server.address`` is set and starts with ``unix://``."""
    configured_address = config.get_option("server.address")
    if configured_address is None:
        return False
    return configured_address.startswith(UNIX_SOCKET_PREFIX)
def start_listening(app: tornado.web.Application) -> None:
    """Wrap ``app`` in an HTTPServer and bind it to the configured address.

    A Unix socket is used when ``server.address`` carries the ``unix://``
    prefix; otherwise a TCP socket is used, in which case a busy port makes
    the server try subsequent ports (up to MAX_PORT_SEARCH_RETRIES attempts).
    """
    ssl_options = _get_ssl_options(
        config.get_option("server.sslCertFile"),
        config.get_option("server.sslKeyFile"),
    )

    # server.maxUploadSize is scaled by 1024 * 1024 to get the buffer size in
    # bytes (presumably the option is expressed in MB -- confirm in config).
    max_buffer_bytes = config.get_option("server.maxUploadSize") * 1024 * 1024

    http_server = HTTPServer(
        app,
        max_buffer_size=max_buffer_bytes,
        ssl_options=ssl_options,
    )

    bind = (
        start_listening_unix_socket
        if server_address_is_unix_socket()
        else start_listening_tcp_socket
    )
    bind(http_server)
def _get_ssl_options(cert_file: str | None, key_file: str | None) -> SSLContext | None:
if bool(cert_file) != bool(key_file):
_LOGGER.error(
"Options 'server.sslCertFile' and 'server.sslKeyFile' must "
"be set together. Set missing options or delete existing options."
)
sys.exit(1)
if cert_file and key_file:
# ssl_ctx.load_cert_chain raise exception as below, but it is not
# sufficiently user-friendly
# FileNotFoundError: [Errno 2] No such file or directory
if not Path(cert_file).exists():
_LOGGER.error("Cert file '%s' does not exist.", cert_file)
sys.exit(1)
if not Path(key_file).exists():
_LOGGER.error("Key file '%s' does not exist.", key_file)
sys.exit(1)
import ssl
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
# When the SSL certificate fails to load, an exception is raised as below,
# but it is not sufficiently user-friendly.
# ssl.SSLError: [SSL] PEM lib (_ssl.c:4067)
try:
ssl_ctx.load_cert_chain(cert_file, key_file)
except ssl.SSLError:
_LOGGER.error(
"Failed to load SSL certificate. Make sure "
"cert file '%s' and key file '%s' are correct.",
cert_file,
key_file,
)
sys.exit(1)
return ssl_ctx
return None
def start_listening_unix_socket(http_server: HTTPServer) -> None:
    """Bind ``http_server`` to the Unix socket named by ``server.address``."""
    configured_address = config.get_option("server.address")
    # Drop the leading "unix://" and expand a leading "~" in what remains.
    socket_path = os.path.expanduser(configured_address[len(UNIX_SOCKET_PREFIX) :])
    http_server.add_socket(tornado.netutil.bind_unix_socket(socket_path))
def start_listening_tcp_socket(http_server: HTTPServer) -> None:
    """Bind ``http_server`` to ``server.address``/``server.port``.

    If the port is already in use and was not manually set, the next port is
    tried (skipping DEVELOPMENT_PORT) and written back to the ``server.port``
    option. A manually set busy port logs an error and exits with status 1.

    Raises
    ------
    RetriesExceeded
        When MAX_PORT_SEARCH_RETRIES busy ports were hit without success.
    OSError
        Any listen failure other than "address already in use".
    """
    failed_attempts = 0
    port = None

    while failed_attempts < MAX_PORT_SEARCH_RETRIES:
        address = config.get_option("server.address")
        port = config.get_option("server.port")

        if int(port) == DEVELOPMENT_PORT:
            _LOGGER.warning(
                "Port %s is reserved for internal development. "
                "It is strongly recommended to select an alternative port "
                "for `server.port`.",
                DEVELOPMENT_PORT,
            )

        try:
            http_server.listen(port, address)
        except OSError as listen_error:
            # Only "address in use" is recoverable; propagate anything else.
            if listen_error.errno != errno.EADDRINUSE:
                raise

            if server_port_is_manually_set():
                _LOGGER.error("Port %s is already in use", port)
                sys.exit(1)
            else:
                _LOGGER.debug(
                    "Port %s already in use, trying to use the next one.", port
                )
                port += 1
                # Don't use the development port here:
                if port == DEVELOPMENT_PORT:
                    port += 1

                config.set_option(
                    "server.port", port, ConfigOption.STREAMLIT_DEFINITION
                )
                failed_attempts += 1
        else:
            # It worked! Stop searching.
            break

    if failed_attempts >= MAX_PORT_SEARCH_RETRIES:
        raise RetriesExceeded(
            f"Cannot start Streamlit server. Port {port} is already in use, and "
            f"Streamlit was unable to find a free port after {MAX_PORT_SEARCH_RETRIES} attempts.",
        )
class Server:
    """Tornado-based web server wrapping a Streamlit Runtime.

    The constructor wires up the in-memory media/upload/session storages and
    the Runtime; `start()` builds the Tornado application, starts listening
    and then starts the runtime.
    """

    def __init__(self, main_script_path: str, is_hello: bool):
        """Create the server. It won't be started yet.

        Parameters
        ----------
        main_script_path : str
            Path of the script, forwarded to RuntimeConfig as `script_path`.
        is_hello : bool
            Forwarded unchanged to RuntimeConfig.
        """
        _set_tornado_log_levels()
        self.initialize_mimetypes()

        self._main_script_path = main_script_path

        # Initialize MediaFileStorage and its associated endpoint
        media_file_storage = MemoryMediaFileStorage(MEDIA_ENDPOINT)
        MediaFileHandler.initialize_storage(media_file_storage)

        uploaded_file_mgr = MemoryUploadedFileManager(UPLOAD_FILE_ENDPOINT)

        self._runtime = Runtime(
            RuntimeConfig(
                script_path=main_script_path,
                command_line=None,
                media_file_storage=media_file_storage,
                uploaded_file_manager=uploaded_file_mgr,
                cache_storage_manager=create_default_cache_storage_manager(),
                is_hello=is_hello,
                session_storage=MemorySessionStorage(
                    ttl_seconds=config.get_option("server.disconnectedSessionTTL")
                ),
            ),
        )

        # Register the media storage with the runtime's stats manager (the
        # same stats manager is handed to StatsRequestHandler in _create_app).
        self._runtime.stats_mgr.register_provider(media_file_storage)

    @classmethod
    def initialize_mimetypes(cls) -> None:
        """Ensures that common mime-types are robust against system misconfiguration."""
        mimetypes.add_type("text/html", ".html")
        mimetypes.add_type("application/javascript", ".js")
        mimetypes.add_type("text/css", ".css")
        mimetypes.add_type("image/webp", ".webp")

    def __repr__(self) -> str:
        """Return the representation produced by `util.repr_`."""
        return util.repr_(self)

    @property
    def main_script_path(self) -> str:
        """The script path this server was constructed with."""
        return self._main_script_path

    async def start(self) -> None:
        """Start the server.

        When this returns, Streamlit is ready to accept new sessions.
        """
        _LOGGER.debug("Starting server...")

        app = self._create_app()
        start_listening(app)

        # Re-read the option: start_listening may have moved to another port
        # and written it back to `server.port`.
        port = config.get_option("server.port")
        _LOGGER.debug("Server started on port %s", port)

        await self._runtime.start()

    @property
    def stopped(self) -> Awaitable[None]:
        """A Future that completes when the Server's run loop has exited."""
        return self._runtime.stopped

    def _create_app(self) -> tornado.web.Application:
        """Create our tornado web app.

        Builds the route table (core endpoints first, then optional
        health-check, static-serving and auth routes, and finally the
        catch-all static file routes outside development mode) and returns the
        configured tornado.web.Application.
        """
        base = config.get_option("server.baseUrlPath")

        # Core routes. Each entry is (url regex, handler class[, init kwargs]).
        routes: list[Any] = [
            (
                make_url_path_regex(base, STREAM_ENDPOINT),
                BrowserWebSocketHandler,
                {"runtime": self._runtime},
            ),
            (
                make_url_path_regex(base, HEALTH_ENDPOINT),
                HealthHandler,
                {"callback": lambda: self._runtime.is_ready_for_browser_connection},
            ),
            (
                make_url_path_regex(base, MESSAGE_ENDPOINT),
                MessageCacheHandler,
                {"cache": self._runtime.message_cache},
            ),
            (
                make_url_path_regex(base, METRIC_ENDPOINT),
                StatsRequestHandler,
                {"stats_manager": self._runtime.stats_mgr},
            ),
            (
                make_url_path_regex(base, HOST_CONFIG_ENDPOINT),
                HostConfigHandler,
            ),
            (
                # The named groups are what UploadFileRequestHandler reads
                # back through `path_kwargs`.
                make_url_path_regex(
                    base,
                    rf"{UPLOAD_FILE_ENDPOINT}/(?P<session_id>[^/]+)/(?P<file_id>[^/]+)",
                ),
                UploadFileRequestHandler,
                {
                    "file_mgr": self._runtime.uploaded_file_mgr,
                    "is_active_session": self._runtime.is_active_session,
                },
            ),
            (
                make_url_path_regex(base, f"{MEDIA_ENDPOINT}/(.*)"),
                MediaFileHandler,
                {"path": ""},
            ),
            (
                make_url_path_regex(base, "component/(.*)"),
                ComponentRequestHandler,
                {"registry": self._runtime.component_registry},
            ),
        ]

        # Optional: script health check endpoint.
        if config.get_option("server.scriptHealthCheckEnabled"):
            routes.extend(
                [
                    (
                        make_url_path_regex(base, SCRIPT_HEALTH_CHECK_ENDPOINT),
                        HealthHandler,
                        {
                            "callback": lambda: self._runtime.does_script_run_without_error()
                        },
                    )
                ]
            )

        # Optional: static files resolved relative to the main script.
        if config.get_option("server.enableStaticServing"):
            routes.extend(
                [
                    (
                        make_url_path_regex(base, "app/static/(.*)"),
                        AppStaticFileHandler,
                        {"path": file_util.get_app_static_dir(self.main_script_path)},
                    ),
                ]
            )

        # Optional: auth routes. The handlers are imported lazily, only when
        # is_authlib_installed() reports True.
        if is_authlib_installed():
            from streamlit.web.server.oauth_authlib_routes import (
                AuthCallbackHandler,
                AuthLoginHandler,
                AuthLogoutHandler,
            )

            routes.extend(
                [
                    (
                        make_url_path_regex(base, OAUTH2_CALLBACK_ENDPOINT),
                        AuthCallbackHandler,
                        {"base_url": base},
                    ),
                    (
                        make_url_path_regex(base, AUTH_LOGIN_ENDPOINT),
                        AuthLoginHandler,
                        {"base_url": base},
                    ),
                    (
                        make_url_path_regex(base, AUTH_LOGOUT_ENDPOINT),
                        AuthLogoutHandler,
                        {"base_url": base},
                    ),
                ]
            )

        if config.get_option("global.developmentMode"):
            # No static routes are registered in development mode.
            _LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_path = file_util.get_static_dir()
            _LOGGER.debug("Serving static content from %s", static_path)

            # Catch-all routes go last so the specific endpoints registered
            # above are listed ahead of the "(.*)" pattern.
            routes.extend(
                [
                    (
                        # We want to remove paths with a trailing slash, but if the path
                        # starts with a double slash //, the redirect will point
                        # the browser to the wrong host.
                        make_url_path_regex(
                            base, "(?!/)(.*)", trailing_slash="required"
                        ),
                        RemoveSlashHandler,
                    ),
                    (
                        make_url_path_regex(base, "(.*)"),
                        StaticFileHandler,
                        {
                            "path": "%s/" % static_path,
                            "default_filename": "index.html",
                            "reserved_paths": [
                                # These paths are required for identifying
                                # the base url path.
                                NEW_HEALTH_ENDPOINT,
                                HOST_CONFIG_ENDPOINT,
                            ],
                        },
                    ),
                    (
                        make_url_path_regex(base, trailing_slash="prohibited"),
                        AddSlashHandler,
                    ),
                ]
            )

        return tornado.web.Application(
            routes,
            cookie_secret=get_cookie_secret(),
            xsrf_cookies=is_xsrf_enabled(),
            # Set the websocket message size. The default value is too low.
            websocket_max_message_size=get_max_message_size_bytes(),
            **TORNADO_SETTINGS,  # type: ignore[arg-type]
        )

    @property
    def browser_is_connected(self) -> bool:
        """True when the runtime state is ONE_OR_MORE_SESSIONS_CONNECTED."""
        return self._runtime.state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED

    @property
    def is_running_hello(self) -> bool:
        """True when the main script is the bundled `streamlit.hello` app file."""
        # Imported lazily, only when this property is accessed.
        from streamlit.hello import streamlit_app

        return self._main_script_path == streamlit_app.__file__

    def stop(self) -> None:
        """Print a "Stopping..." notice to the CLI and stop the runtime."""
        cli_util.print_to_cli("  Stopping...", fg="blue")
        self._runtime.stop()
def _set_tornado_log_levels() -> None:
    """Raise Tornado's loggers to ERROR unless running in development mode."""
    if config.get_option("global.developmentMode"):
        return

    # Hide logs unless they're super important.
    # Example of stuff we don't care about: 404 about .js.map files.
    for logger_name in ("tornado.access", "tornado.application", "tornado.general"):
        logging.getLogger(logger_name).setLevel(logging.ERROR)

View File

@@ -0,0 +1,159 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server related utility functions."""
from __future__ import annotations
from typing import TYPE_CHECKING, Final, Literal
from urllib.parse import urljoin
from streamlit import config, net_util, url_util
from streamlit.runtime.secrets import secrets_singleton
if TYPE_CHECKING:
from tornado.web import RequestHandler
# The port reserved for internal development. Returned as the browser-facing
# port by _get_browser_address_bar_port() when global.developmentMode is on.
DEVELOPMENT_PORT: Final = 3000

# Cookie name used for auth (not referenced elsewhere in this module;
# presumably consumed by the auth handlers -- confirm against callers).
AUTH_COOKIE_NAME: Final = "_streamlit_user"
def is_url_from_allowed_origins(url: str) -> bool:
    """Return True if URL is from allowed origins (for CORS purpose).

    Allowed origins:
    1. localhost (also ``0.0.0.0`` and ``127.0.0.1``)
    2. The manually configured ``browser.serverAddress``, if any.
    3. The internal and external IP addresses of the machine where this
       function was called from.

    If `server.enableCORS` is False, this allows all origins.
    """
    if not config.get_option("server.enableCORS"):
        # Allow everything when CORS is disabled.
        return True

    hostname = url_util.get_hostname(url)

    # Cheap literal candidates are checked first.
    literal_hosts = ("localhost", "0.0.0.0", "127.0.0.1")
    # Lazy candidates, ordered so that the manually specified server address is
    # consulted before the lookups that depend on HTTP requests or sockets.
    lazy_hosts = (
        _get_server_address_if_manually_set,
        net_util.get_internal_ip,
        net_util.get_external_ip,
    )

    for literal in literal_hosts:
        if hostname == literal:
            return True

    for resolve_host in lazy_hosts:
        candidate = resolve_host()
        if candidate is not None and hostname == candidate:
            return True

    return False
def get_cookie_secret() -> str:
    """Get the cookie secret.

    Starts from the ``server.cookieSecret`` option. When a secrets TOML exists
    and has a non-empty ``auth`` section, that section's ``cookie_secret``
    entry (if present) takes precedence.
    """
    configured_secret: str = config.get_option("server.cookieSecret")

    if not secrets_singleton.load_if_toml_exists():
        return configured_secret

    auth_section = secrets_singleton.get("auth")
    if not auth_section:
        return configured_secret

    return auth_section.get("cookie_secret", configured_secret)
def is_xsrf_enabled():
    """Tell whether XSRF protection should be active.

    The ``server.enableXsrfProtection`` option wins when it is truthy.
    Otherwise protection is considered enabled exactly when a secrets TOML
    exists and contains an ``auth`` section.
    """
    xsrf_option = config.get_option("server.enableXsrfProtection")
    if xsrf_option:
        return xsrf_option

    if not secrets_singleton.load_if_toml_exists():
        return xsrf_option

    return secrets_singleton.get("auth", None) is not None
def _get_server_address_if_manually_set() -> str | None:
    """Return the hostname of ``browser.serverAddress`` if manually set, else None."""
    option_name = "browser.serverAddress"
    if not config.is_manually_set(option_name):
        return None
    return url_util.get_hostname(config.get_option(option_name))
def make_url_path_regex(
    *path, trailing_slash: Literal["optional", "required", "prohibited"] = "optional"
) -> str:
    """Get a regex of the form ^/foo/bar/baz/?$ for a path (foo, bar, baz).

    Falsy components are dropped and leading/trailing slashes are stripped
    from the remaining ones. ``trailing_slash`` selects the suffix: ``/?`` for
    "optional", ``/`` for "required", and nothing otherwise.
    """
    # Drop falsy components first, then trim slashes from what is left.
    components = [part.strip("/") for part in path if part]

    if trailing_slash == "required":
        suffix = "/"
    elif trailing_slash == "optional":
        suffix = "/?"
    else:
        suffix = ""

    joined = "/".join(components)
    return f"^/{joined}{suffix}$"
def get_url(host_ip: str) -> str:
    """Get the URL for any app served at the given host_ip.

    Parameters
    ----------
    host_ip : str
        The IP address of the machine that is running the Streamlit Server.

    Returns
    -------
    str
        The URL, built as ``<protocol>://<host>:<port><base path>``. The
        protocol is ``https`` when ``server.sslCertFile`` is set.
    """
    scheme = "http"
    if config.get_option("server.sslCertFile"):
        scheme = "https"

    port = _get_browser_address_bar_port()

    trimmed_base = config.get_option("server.baseUrlPath").strip("/")
    path_suffix = "/" + trimmed_base if trimmed_base else trimmed_base

    host = host_ip.strip("/")
    return f"{scheme}://{host}:{port}{path_suffix}"
def _get_browser_address_bar_port() -> int:
    """Get the port that will be shown in the browser's address bar.

    That is, this is the port where static assets will be served from. In dev,
    this is DEVELOPMENT_PORT, which is different from the port that will be
    used to connect to the server-browser websocket.
    """
    dev_mode = config.get_option("global.developmentMode")
    return DEVELOPMENT_PORT if dev_mode else int(config.get_option("browser.serverPort"))
def emit_endpoint_deprecation_notice(handler: RequestHandler, new_path: str) -> None:
    """Flag the current endpoint as deprecated via HTTP response headers.

    Sets a ``Deprecation`` header and a ``Link`` header (``rel="alternate"``)
    pointing at ``new_path`` resolved against the request's own protocol and
    host.
    """
    handler.set_header("Deprecation", True)

    request = handler.request
    origin = f"{request.protocol}://{request.host}"
    replacement_url = urljoin(origin, new_path)
    handler.set_header("Link", f'<{replacement_url}>; rel="alternate"')

View File

@@ -0,0 +1,95 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import tornado.web
from streamlit.web.server import allow_cross_origin_requests
from streamlit.web.server.server_util import emit_endpoint_deprecation_notice
if TYPE_CHECKING:
from streamlit.proto.openmetrics_data_model_pb2 import MetricSet as MetricSetProto
from streamlit.runtime.stats import CacheStat, StatsManager
class StatsRequestHandler(tornado.web.RequestHandler):
    """Serve cache statistics as OpenMetrics text or as a serialized protobuf."""

    def initialize(self, stats_manager: StatsManager) -> None:
        """Store the stats manager whose ``get_stats()`` feeds every response."""
        self._manager = stats_manager

    def set_default_headers(self):
        """Add a wildcard CORS header when cross-origin requests are allowed."""
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self):
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()

    def get(self) -> None:
        """Write the current stats in the format selected by the Accept header."""
        # Requests whose URI lacks "_stcore/" get a deprecation notice
        # pointing at the /_stcore/metrics path.
        if self.request.uri and "_stcore/" not in self.request.uri:
            emit_endpoint_deprecation_notice(self, new_path="/_stcore/metrics")

        # Take a single snapshot and use it for whichever format is requested,
        # so the stats are gathered only once per request.
        stats = self._manager.get_stats()

        # If the request asked for protobuf output, we return a serialized
        # protobuf. Else we return text.
        if "application/x-protobuf" in self.request.headers.get_list("Accept"):
            self.write(self._stats_to_proto(stats).SerializeToString())
            self.set_header("Content-Type", "application/x-protobuf")
        else:
            self.write(self._stats_to_text(stats))
            self.set_header("Content-Type", "application/openmetrics-text")
        self.set_status(200)

    @staticmethod
    def _stats_to_text(stats: list[CacheStat]) -> str:
        """Render ``stats`` as text: three header lines, one line per stat, EOF."""
        metric_type = "# TYPE cache_memory_bytes gauge"
        metric_unit = "# UNIT cache_memory_bytes bytes"
        metric_help = "# HELP Total memory consumed by a cache."
        openmetrics_eof = "# EOF\n"

        # Format: header, stats, EOF
        result = [metric_type, metric_unit, metric_help]
        result.extend(stat.to_metric_str() for stat in stats)
        result.append(openmetrics_eof)

        return "\n".join(result)

    @staticmethod
    def _stats_to_proto(stats: list[CacheStat]) -> MetricSetProto:
        """Build a MetricSet holding one ``cache_memory_bytes`` gauge family."""
        # Lazy load the import of this proto message for better performance:
        from streamlit.proto.openmetrics_data_model_pb2 import GAUGE
        from streamlit.proto.openmetrics_data_model_pb2 import (
            MetricSet as MetricSetProto,
        )

        metric_set = MetricSetProto()

        # The family is created inside `metric_set`, so filling it in below
        # populates the message that gets returned; no second MetricSet or
        # copy of the family is needed.
        metric_family = metric_set.metric_families.add()
        metric_family.name = "cache_memory_bytes"
        metric_family.type = GAUGE
        metric_family.unit = "bytes"
        metric_family.help = "Total memory consumed by a cache."

        for stat in stats:
            metric_proto = metric_family.metrics.add()
            stat.marshall_metric_proto(metric_proto)

        return metric_set

View File

@@ -0,0 +1,137 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
import tornado.httputil
import tornado.web
from streamlit import config
from streamlit.runtime.uploaded_file_manager import UploadedFileRec
from streamlit.web.server import routes, server_util
from streamlit.web.server.server_util import is_xsrf_enabled
if TYPE_CHECKING:
from streamlit.runtime.memory_uploaded_file_manager import MemoryUploadedFileManager
class UploadFileRequestHandler(tornado.web.RequestHandler):
    """Implements the PUT and DELETE handlers of the upload_file endpoint."""

    def initialize(
        self,
        file_mgr: MemoryUploadedFileManager,
        is_active_session: Callable[[str], bool],
    ):
        """
        Parameters
        ----------
        file_mgr : UploadedFileManager
            The server's singleton UploadedFileManager. All file uploads
            go here.
        is_active_session:
            A function that returns true if a session_id belongs to an active
            session.
        """
        self._file_mgr = file_mgr
        self._is_active_session = is_active_session

    def set_default_headers(self):
        """Attach the CORS headers that accompany every response."""
        self.set_header("Access-Control-Allow-Methods", "PUT, OPTIONS, DELETE")
        self.set_header("Access-Control-Allow-Headers", "Content-Type")
        if is_xsrf_enabled():
            # With XSRF protection on, only the app's own URL is allowed as
            # origin, and the XSRF token header must be accepted as well.
            self.set_header(
                "Access-Control-Allow-Origin",
                server_util.get_url(config.get_option("browser.serverAddress")),
            )
            self.set_header("Access-Control-Allow-Headers", "X-Xsrftoken, Content-Type")
            self.set_header("Vary", "Origin")
            self.set_header("Access-Control-Allow-Credentials", "true")
        elif routes.allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self, **kwargs):
        """/OPTIONS handler for preflight CORS checks.

        When a browser is making a CORS request, it may sometimes first
        send an OPTIONS request, to check whether the server understands the
        CORS protocol. This is optional, and doesn't happen for every request
        or in every browser. If an OPTIONS request does get sent, and is not
        then handled by the server, the browser will fail the underlying
        request.

        The proper way to handle this is to send a 204 response ("no content")
        with the CORS headers attached. (These headers are automatically added
        to every outgoing response, including OPTIONS responses,
        via set_default_headers().)

        See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
        """
        self.set_status(204)
        self.finish()

    def put(self, **kwargs):
        """Receive an uploaded file and add it to our UploadedFileManager.

        Responds 204 on success, or 400 when the session is not active or the
        body does not contain exactly one file.
        """
        args: dict[str, list[bytes]] = {}
        files: dict[str, list[Any]] = {}

        session_id = self.path_kwargs["session_id"]
        file_id = self.path_kwargs["file_id"]

        # A request without a Content-Type header must not blow up with a
        # KeyError (HTTP 500). With an empty content type nothing is parsed,
        # and the "Expected 1 file" check below answers with a 400 instead.
        tornado.httputil.parse_body_arguments(
            content_type=self.request.headers.get("Content-Type", ""),
            body=self.request.body,
            arguments=args,
            files=files,
        )

        # Any failure of the session lookup itself is reported as a 400 too,
        # using the exception text as the reason.
        try:
            session_is_active = self._is_active_session(session_id)
        except Exception as e:
            self.send_error(400, reason=str(e))
            return
        if not session_is_active:
            self.send_error(400, reason="Invalid session_id")
            return

        uploaded_files: list[UploadedFileRec] = [
            UploadedFileRec(
                file_id=file_id,
                name=file["filename"],
                type=file["content_type"],
                data=file["body"],
            )
            for flist in files.values()
            for file in flist
        ]

        if len(uploaded_files) != 1:
            self.send_error(
                400, reason=f"Expected 1 file, but got {len(uploaded_files)}"
            )
            return

        self._file_mgr.add_file(session_id=session_id, file=uploaded_files[0])
        self.set_status(204)

    def delete(self, **kwargs):
        """Delete file request handler.

        Removes the file identified by the URL's session_id/file_id and
        responds with 204.
        """
        session_id = self.path_kwargs["session_id"]
        file_id = self.path_kwargs["file_id"]

        self._file_mgr.remove_file(session_id=session_id, file_id=file_id)
        self.set_status(204)

View File

@@ -0,0 +1,56 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from streamlit import runtime
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler
# Message passed to show_deprecation_warning() on every call of
# _get_websocket_headers().
_GET_WEBSOCKET_HEADERS_DEPRECATE_MSG = (
    "The `_get_websocket_headers` function is deprecated and will be removed "
    "in a future version of Streamlit. Please use `st.context.headers` instead."
)
@gather_metrics("_get_websocket_headers")
def _get_websocket_headers() -> dict[str, str] | None:
    """Return a copy of the HTTP request headers for the current session's
    WebSocket connection. If there's no active session, return None instead.
    Raise an error if the server is not running.

    Deprecated: a deprecation warning is shown on every call; use
    `st.context.headers` instead.

    Note to the intrepid: this is an UNSUPPORTED, INTERNAL API. (We don't have plans
    to remove it without a replacement, but we don't consider this a production-ready
    function, and its signature may change without a deprecation warning.)

    Raises
    ------
    RuntimeError
        If the session's client is not a BrowserWebSocketHandler.
    """
    show_deprecation_warning(_GET_WEBSOCKET_HEADERS_DEPRECATE_MSG)

    script_ctx = get_script_run_ctx()
    client = (
        None
        if script_ctx is None
        else runtime.get_instance().get_client(script_ctx.session_id)
    )
    if client is None:
        return None

    if isinstance(client, BrowserWebSocketHandler):
        return dict(client.request.headers)

    raise RuntimeError(
        f"SessionClient is not a BrowserWebSocketHandler! ({client})"
    )