Mise à jour de Monitor.py et autres scripts
This commit is contained in:
@@ -0,0 +1,255 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""MySQL Connector/Python - MySQL driver written in Python."""
|
||||
|
||||
__all__ = ["CMySQLConnection", "MySQLConnection", "connect"]
|
||||
|
||||
import random
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ..constants import DEFAULT_CONFIGURATION
|
||||
from ..errors import Error, InterfaceError, ProgrammingError
|
||||
from ..pooling import ERROR_NO_CEXT
|
||||
from .abstracts import MySQLConnectionAbstract
|
||||
from .connection import MySQLConnection
|
||||
|
||||
try:
|
||||
import dns.exception
|
||||
import dns.resolver
|
||||
except ImportError:
|
||||
HAVE_DNSPYTHON = False
|
||||
else:
|
||||
HAVE_DNSPYTHON = True
|
||||
|
||||
|
||||
try:
|
||||
from .connection_cext import CMySQLConnection
|
||||
except ImportError:
|
||||
CMySQLConnection = None
|
||||
|
||||
|
||||
async def connect(*args: Any, **kwargs: Any) -> MySQLConnectionAbstract:
|
||||
"""Creates or gets a MySQL connection object.
|
||||
|
||||
In its simpliest form, `connect()` will open a connection to a
|
||||
MySQL server and return a `MySQLConnectionAbstract` subclass
|
||||
object such as `MySQLConnection` or `CMySQLConnection`.
|
||||
|
||||
When any connection pooling arguments are given, for example `pool_name`
|
||||
or `pool_size`, a pool is created or a previously one is used to return
|
||||
a `PooledMySQLConnection`.
|
||||
|
||||
Args:
|
||||
*args: N/A.
|
||||
**kwargs: For a complete list of possible arguments, see [1]. If no arguments
|
||||
are given, it uses the already configured or default values.
|
||||
|
||||
Returns:
|
||||
A `MySQLConnectionAbstract` subclass instance (such as `MySQLConnection` or
|
||||
a `CMySQLConnection`) instance.
|
||||
|
||||
Examples:
|
||||
A connection with the MySQL server can be established using either the
|
||||
`mysql.connector.connect()` method or a `MySQLConnectionAbstract` subclass:
|
||||
```
|
||||
>>> from mysql.connector.aio import MySQLConnection, HAVE_CEXT
|
||||
>>>
|
||||
>>> cnx1 = await mysql.connector.aio.connect(user='joe', database='test')
|
||||
>>> cnx2 = MySQLConnection(user='joe', database='test')
|
||||
>>> await cnx2.connect()
|
||||
>>>
|
||||
>>> cnx3 = None
|
||||
>>> if HAVE_CEXT:
|
||||
>>> from mysql.connector.aio import CMySQLConnection
|
||||
>>> cnx3 = CMySQLConnection(user='joe', database='test')
|
||||
```
|
||||
|
||||
References:
|
||||
[1]: https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html
|
||||
"""
|
||||
# DNS SRV
|
||||
dns_srv = kwargs.pop("dns_srv") if "dns_srv" in kwargs else False
|
||||
|
||||
if not isinstance(dns_srv, bool):
|
||||
raise InterfaceError("The value of 'dns-srv' must be a boolean")
|
||||
|
||||
if dns_srv:
|
||||
if not HAVE_DNSPYTHON:
|
||||
raise InterfaceError(
|
||||
"MySQL host configuration requested DNS "
|
||||
"SRV. This requires the Python dnspython "
|
||||
"module. Please refer to documentation"
|
||||
)
|
||||
if "unix_socket" in kwargs:
|
||||
raise InterfaceError(
|
||||
"Using Unix domain sockets with DNS SRV lookup is not allowed"
|
||||
)
|
||||
if "port" in kwargs:
|
||||
raise InterfaceError(
|
||||
"Specifying a port number with DNS SRV lookup is not allowed"
|
||||
)
|
||||
if "failover" in kwargs:
|
||||
raise InterfaceError(
|
||||
"Specifying multiple hostnames with DNS SRV look up is not allowed"
|
||||
)
|
||||
if "host" not in kwargs:
|
||||
kwargs["host"] = DEFAULT_CONFIGURATION["host"]
|
||||
|
||||
try:
|
||||
srv_records = dns.resolver.query(kwargs["host"], "SRV")
|
||||
except dns.exception.DNSException:
|
||||
raise InterfaceError(
|
||||
f"Unable to locate any hosts for '{kwargs['host']}'"
|
||||
) from None
|
||||
|
||||
failover = []
|
||||
for srv in srv_records:
|
||||
failover.append(
|
||||
{
|
||||
"host": srv.target.to_text(omit_final_dot=True),
|
||||
"port": srv.port,
|
||||
"priority": srv.priority,
|
||||
"weight": srv.weight,
|
||||
}
|
||||
)
|
||||
|
||||
failover.sort(key=lambda x: (x["priority"], -x["weight"]))
|
||||
kwargs["failover"] = [
|
||||
{"host": srv["host"], "port": srv["port"]} for srv in failover
|
||||
]
|
||||
|
||||
# Failover
|
||||
if "failover" in kwargs:
|
||||
return await _get_failover_connection(**kwargs)
|
||||
|
||||
# Use C Extension by default
|
||||
use_pure = kwargs.get("use_pure", False)
|
||||
if "use_pure" in kwargs:
|
||||
del kwargs["use_pure"] # Remove 'use_pure' from kwargs
|
||||
if not use_pure and CMySQLConnection is None:
|
||||
raise ImportError(ERROR_NO_CEXT)
|
||||
|
||||
if CMySQLConnection and not use_pure:
|
||||
cnx = CMySQLConnection(*args, **kwargs)
|
||||
else:
|
||||
cnx = MySQLConnection(*args, **kwargs)
|
||||
await cnx.connect()
|
||||
return cnx
|
||||
|
||||
|
||||
async def _get_failover_connection(**kwargs: Any) -> MySQLConnectionAbstract:
|
||||
"""Return a MySQL connection and try to failover if needed.
|
||||
|
||||
An InterfaceError is raise when no MySQL is available. ValueError is
|
||||
raised when the failover server configuration contains an illegal
|
||||
connection argument. Supported arguments are user, password, host, port,
|
||||
unix_socket and database. ValueError is also raised when the failover
|
||||
argument was not provided.
|
||||
|
||||
Returns MySQLConnection instance.
|
||||
"""
|
||||
config = kwargs.copy()
|
||||
try:
|
||||
failover = config["failover"]
|
||||
except KeyError:
|
||||
raise ValueError("failover argument not provided") from None
|
||||
del config["failover"]
|
||||
|
||||
support_cnx_args = set(
|
||||
[
|
||||
"user",
|
||||
"password",
|
||||
"host",
|
||||
"port",
|
||||
"unix_socket",
|
||||
"database",
|
||||
"pool_name",
|
||||
"pool_size",
|
||||
"priority",
|
||||
]
|
||||
)
|
||||
|
||||
# First check if we can add all use the configuration
|
||||
priority_count = 0
|
||||
for server in failover:
|
||||
diff = set(server.keys()) - support_cnx_args
|
||||
if diff:
|
||||
arg = "s" if len(diff) > 1 else ""
|
||||
lst = ", ".join(diff)
|
||||
raise ValueError(
|
||||
f"Unsupported connection argument {arg} in failover: {lst}"
|
||||
)
|
||||
if hasattr(server, "priority"):
|
||||
priority_count += 1
|
||||
|
||||
server["priority"] = server.get("priority", 100)
|
||||
if server["priority"] < 0 or server["priority"] > 100:
|
||||
raise InterfaceError(
|
||||
"Priority value should be in the range of 0 to 100, "
|
||||
f"got : {server['priority']}"
|
||||
)
|
||||
if not isinstance(server["priority"], int):
|
||||
raise InterfaceError(
|
||||
"Priority value should be an integer in the range of 0 to "
|
||||
f"100, got : {server['priority']}"
|
||||
)
|
||||
|
||||
if 0 < priority_count < len(failover):
|
||||
raise ProgrammingError(
|
||||
"You must either assign no priority to any "
|
||||
"of the routers or give a priority for "
|
||||
"every router"
|
||||
)
|
||||
|
||||
server_directory = {}
|
||||
server_priority_list = []
|
||||
for server in sorted(failover, key=lambda x: x["priority"], reverse=True):
|
||||
if server["priority"] not in server_directory:
|
||||
server_directory[server["priority"]] = [server]
|
||||
server_priority_list.append(server["priority"])
|
||||
else:
|
||||
server_directory[server["priority"]].append(server)
|
||||
|
||||
for priority in server_priority_list:
|
||||
failover_list = server_directory[priority]
|
||||
for _ in range(len(failover_list)):
|
||||
last = len(failover_list) - 1
|
||||
index = random.randint(0, last)
|
||||
server = failover_list.pop(index)
|
||||
new_config = config.copy()
|
||||
new_config.update(server)
|
||||
new_config.pop("priority", None)
|
||||
try:
|
||||
return await connect(**new_config)
|
||||
except Error:
|
||||
# If we failed to connect, we try the next server
|
||||
pass
|
||||
|
||||
raise InterfaceError("Unable to connect to any of the target hosts")
|
||||
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2009, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Decorators Hub."""
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from ..constants import RefreshOption
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .abstracts import MySQLConnectionAbstract
|
||||
|
||||
|
||||
def cmd_refresh_verify_options() -> Callable:
|
||||
"""Decorator verifying which options are relevant and which aren't based on
|
||||
the server version the client is connecting to."""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(
|
||||
cnx: "MySQLConnectionAbstract", *args: Any, **kwargs: Any
|
||||
) -> Callable:
|
||||
options: int = args[0]
|
||||
if (options & RefreshOption.GRANT) and cnx.get_server_version() >= (
|
||||
9,
|
||||
2,
|
||||
0,
|
||||
):
|
||||
warnings.warn(
|
||||
"As of MySQL Server 9.2.0, refreshing grant tables is not needed "
|
||||
"if you use statements GRANT, REVOKE, CREATE, DROP, or ALTER. "
|
||||
"You should expect this option to be unsupported in a future "
|
||||
"version of MySQL Connector/Python when MySQL Server removes it.",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
return await func(cnx, options, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
2442
myenv/lib/python3.11/site-packages/mysql/connector/aio/abstracts.py
Normal file
2442
myenv/lib/python3.11/site-packages/mysql/connector/aio/abstracts.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,335 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Implementing support for MySQL Authentication Plugins."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["MySQLAuthenticator"]
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from ..errors import InterfaceError, NotSupportedError, get_exception
|
||||
from ..protocol import (
|
||||
AUTH_SWITCH_STATUS,
|
||||
DEFAULT_CHARSET_ID,
|
||||
DEFAULT_MAX_ALLOWED_PACKET,
|
||||
ERR_STATUS,
|
||||
EXCHANGE_FURTHER_STATUS,
|
||||
MFA_STATUS,
|
||||
OK_STATUS,
|
||||
)
|
||||
from ..types import HandShakeType
|
||||
from .logger import logger
|
||||
from .plugins import MySQLAuthPlugin, get_auth_plugin
|
||||
from .protocol import MySQLProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .network import MySQLSocket
|
||||
|
||||
|
||||
class MySQLAuthenticator:
|
||||
"""Implements the authentication phase."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Constructor."""
|
||||
self._username: str = ""
|
||||
self._passwords: Dict[int, str] = {}
|
||||
self._plugin_config: Dict[str, Any] = {}
|
||||
self._ssl_enabled: bool = False
|
||||
self._auth_strategy: Optional[MySQLAuthPlugin] = None
|
||||
self._auth_plugin_class: Optional[str] = None
|
||||
|
||||
@property
|
||||
def ssl_enabled(self) -> bool:
|
||||
"""Signals whether or not SSL is enabled."""
|
||||
return self._ssl_enabled
|
||||
|
||||
@property
|
||||
def plugin_config(self) -> Dict[str, Any]:
|
||||
"""Custom arguments that are being provided to the authentication plugin.
|
||||
|
||||
The parameters defined here will override the ones defined in the
|
||||
auth plugin itself.
|
||||
|
||||
The plugin config is a read-only property - the plugin configuration
|
||||
provided when invoking `authenticate()` is recorded and can be queried
|
||||
by accessing this property.
|
||||
|
||||
Returns:
|
||||
dict: The latest plugin configuration provided when invoking
|
||||
`authenticate()`.
|
||||
"""
|
||||
return self._plugin_config
|
||||
|
||||
def update_plugin_config(self, config: Dict[str, Any]) -> None:
|
||||
"""Update the 'plugin_config' instance variable"""
|
||||
self._plugin_config.update(config)
|
||||
|
||||
def _switch_auth_strategy(
|
||||
self,
|
||||
new_strategy_name: str,
|
||||
strategy_class: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
password_factor: int = 1,
|
||||
) -> None:
|
||||
"""Switch the authorization plugin.
|
||||
|
||||
Args:
|
||||
new_strategy_name: New authorization plugin name to switch to.
|
||||
strategy_class: New authorization plugin class to switch to
|
||||
(has higher precedence than the authorization plugin name).
|
||||
username: Username to be used - if not defined, the username
|
||||
provided when `authentication()` was invoked is used.
|
||||
password_factor: Up to three levels of authentication (MFA) are allowed,
|
||||
hence you can choose the password corresponding to the 1st,
|
||||
2nd, or 3rd factor - 1st is the default.
|
||||
"""
|
||||
if username is None:
|
||||
username = self._username
|
||||
|
||||
if strategy_class is None:
|
||||
strategy_class = self._auth_plugin_class
|
||||
|
||||
logger.debug("Switching to strategy %s", new_strategy_name)
|
||||
self._auth_strategy = get_auth_plugin(
|
||||
plugin_name=new_strategy_name, auth_plugin_class=strategy_class
|
||||
)(
|
||||
username,
|
||||
self._passwords.get(password_factor, ""),
|
||||
ssl_enabled=self.ssl_enabled,
|
||||
)
|
||||
|
||||
async def _mfa_n_factor(
|
||||
self,
|
||||
sock: MySQLSocket,
|
||||
pkt: bytes,
|
||||
) -> Optional[bytes]:
|
||||
"""Handle MFA (Multi-Factor Authentication) response.
|
||||
|
||||
Up to three levels of authentication (MFA) are allowed.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
pkt: MFA response.
|
||||
|
||||
Returns:
|
||||
ok_packet: If last server's response is an OK packet.
|
||||
None: If last server's response isn't an OK packet and no ERROR was raised.
|
||||
|
||||
Raises:
|
||||
InterfaceError: If got an invalid N factor.
|
||||
errors.ErrorTypes: If got an ERROR response.
|
||||
"""
|
||||
n_factor = 2
|
||||
while pkt[4] == MFA_STATUS:
|
||||
if n_factor not in self._passwords:
|
||||
raise InterfaceError(
|
||||
"Failed Multi Factor Authentication (invalid N factor)"
|
||||
)
|
||||
|
||||
new_strategy_name, auth_data = MySQLProtocol.parse_auth_next_factor(pkt)
|
||||
self._switch_auth_strategy(new_strategy_name, password_factor=n_factor)
|
||||
logger.debug("MFA %i factor %s", n_factor, self._auth_strategy.name)
|
||||
|
||||
pkt = await self._auth_strategy.auth_switch_response(
|
||||
sock, auth_data, **self._plugin_config
|
||||
)
|
||||
|
||||
if pkt[4] == EXCHANGE_FURTHER_STATUS:
|
||||
auth_data = MySQLProtocol.parse_auth_more_data(pkt)
|
||||
pkt = await self._auth_strategy.auth_more_response(
|
||||
sock, auth_data, **self._plugin_config
|
||||
)
|
||||
|
||||
if pkt[4] == OK_STATUS:
|
||||
logger.debug("MFA completed succesfully")
|
||||
return pkt
|
||||
|
||||
if pkt[4] == ERR_STATUS:
|
||||
raise get_exception(pkt)
|
||||
|
||||
n_factor += 1
|
||||
|
||||
logger.warning("MFA terminated with a no ok packet")
|
||||
return None
|
||||
|
||||
async def _handle_server_response(
|
||||
self,
|
||||
sock: MySQLSocket,
|
||||
pkt: bytes,
|
||||
) -> Optional[bytes]:
|
||||
"""Handle server's response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
pkt: Server's response after completing the `HandShakeResponse`.
|
||||
|
||||
Returns:
|
||||
ok_packet: If last server's response is an OK packet.
|
||||
None: If last server's response isn't an OK packet and no ERROR was raised.
|
||||
|
||||
Raises:
|
||||
errors.ErrorTypes: If got an ERROR response.
|
||||
NotSupportedError: If got Authentication with old (insecure) passwords.
|
||||
"""
|
||||
if pkt[4] == AUTH_SWITCH_STATUS and len(pkt) == 5:
|
||||
raise NotSupportedError(
|
||||
"Authentication with old (insecure) passwords "
|
||||
"is not supported. For more information, lookup "
|
||||
"Password Hashing in the latest MySQL manual"
|
||||
)
|
||||
|
||||
if pkt[4] == AUTH_SWITCH_STATUS:
|
||||
logger.debug("Server's response is an auth switch request")
|
||||
new_strategy_name, auth_data = MySQLProtocol.parse_auth_switch_request(pkt)
|
||||
self._switch_auth_strategy(new_strategy_name)
|
||||
pkt = await self._auth_strategy.auth_switch_response(
|
||||
sock, auth_data, **self._plugin_config
|
||||
)
|
||||
|
||||
if pkt[4] == EXCHANGE_FURTHER_STATUS:
|
||||
logger.debug("Exchanging further packets")
|
||||
auth_data = MySQLProtocol.parse_auth_more_data(pkt)
|
||||
pkt = await self._auth_strategy.auth_more_response(
|
||||
sock, auth_data, **self._plugin_config
|
||||
)
|
||||
|
||||
if pkt[4] == OK_STATUS:
|
||||
logger.debug("%s completed succesfully", self._auth_strategy.name)
|
||||
return pkt
|
||||
|
||||
if pkt[4] == MFA_STATUS:
|
||||
logger.debug("Starting multi-factor authentication")
|
||||
logger.debug("MFA 1 factor %s", self._auth_strategy.name)
|
||||
return await self._mfa_n_factor(sock, pkt)
|
||||
|
||||
if pkt[4] == ERR_STATUS:
|
||||
raise get_exception(pkt)
|
||||
|
||||
return None
|
||||
|
||||
async def authenticate(
|
||||
self,
|
||||
sock: MySQLSocket,
|
||||
handshake: HandShakeType,
|
||||
username: str = "",
|
||||
password1: str = "",
|
||||
password2: str = "",
|
||||
password3: str = "",
|
||||
database: Optional[str] = None,
|
||||
charset: int = DEFAULT_CHARSET_ID,
|
||||
client_flags: int = 0,
|
||||
ssl_enabled: bool = False,
|
||||
max_allowed_packet: int = DEFAULT_MAX_ALLOWED_PACKET,
|
||||
auth_plugin: Optional[str] = None,
|
||||
auth_plugin_class: Optional[str] = None,
|
||||
conn_attrs: Optional[Dict[str, str]] = None,
|
||||
is_change_user_request: bool = False,
|
||||
read_timeout: Optional[int] = None,
|
||||
write_timeout: Optional[int] = None,
|
||||
) -> bytes:
|
||||
"""Perform the authentication phase.
|
||||
|
||||
During re-authentication you must set `is_change_user_request` to True.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
handshake: Initial handshake.
|
||||
username: Account's username.
|
||||
password1: Account's password factor 1.
|
||||
password2: Account's password factor 2.
|
||||
password3: Account's password factor 3.
|
||||
database: Initial database name for the connection.
|
||||
charset: Client charset (see [1]), only the lower 8-bits.
|
||||
client_flags: Integer representing client capabilities flags.
|
||||
ssl_enabled: Boolean indicating whether SSL is enabled,
|
||||
max_allowed_packet: Maximum packet size.
|
||||
auth_plugin: Authorization plugin name.
|
||||
auth_plugin_class: Authorization plugin class (has higher precedence
|
||||
than the authorization plugin name).
|
||||
conn_attrs: Connection attributes.
|
||||
is_change_user_request: Whether is a `change user request` operation or not.
|
||||
read_timeout: Timeout in seconds upto which the connector should wait for
|
||||
the server to reply back before raising an ReadTimeoutError.
|
||||
write_timeout: Timeout in seconds upto which the connector should spend to
|
||||
send data to the server before raising an WriteTimeoutError.
|
||||
|
||||
Returns:
|
||||
ok_packet: OK packet.
|
||||
|
||||
Raises:
|
||||
InterfaceError: If OK packet is NULL.
|
||||
ReadTimeoutError: If the time taken for the server to reply back exceeds
|
||||
'read_timeout' (if set).
|
||||
WriteTimeoutError: If the time taken to send data packets to the server
|
||||
exceeds 'write_timeout' (if set).
|
||||
|
||||
References:
|
||||
[1]: https://dev.mysql.com/doc/dev/mysql-server/latest/\
|
||||
page_protocol_basic_character_set.html#a_protocol_character_set
|
||||
"""
|
||||
# update credentials, plugin config and plugin class
|
||||
self._username = username
|
||||
self._passwords = {1: password1, 2: password2, 3: password3}
|
||||
self._ssl_enabled = ssl_enabled
|
||||
self._auth_plugin_class = auth_plugin_class
|
||||
|
||||
# client's handshake response
|
||||
response_payload, self._auth_strategy = MySQLProtocol.make_auth(
|
||||
handshake=handshake,
|
||||
username=username,
|
||||
password=password1,
|
||||
database=database,
|
||||
charset=charset,
|
||||
client_flags=client_flags,
|
||||
max_allowed_packet=max_allowed_packet,
|
||||
auth_plugin=auth_plugin,
|
||||
auth_plugin_class=auth_plugin_class,
|
||||
conn_attrs=conn_attrs,
|
||||
is_change_user_request=is_change_user_request,
|
||||
ssl_enabled=self.ssl_enabled,
|
||||
plugin_config=self.plugin_config,
|
||||
)
|
||||
|
||||
# client sends transaction response
|
||||
send_args = (
|
||||
(0, 0, write_timeout)
|
||||
if is_change_user_request
|
||||
else (None, None, write_timeout)
|
||||
)
|
||||
await sock.write(response_payload, *send_args)
|
||||
|
||||
# server replies back
|
||||
pkt = bytes(await sock.read(read_timeout))
|
||||
|
||||
ok_pkt = await self._handle_server_response(sock, pkt)
|
||||
if ok_pkt is None:
|
||||
raise InterfaceError("Got a NULL ok_pkt") from None
|
||||
|
||||
return ok_pkt
|
||||
@@ -0,0 +1,686 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""This module contains the MySQL Server Character Sets."""
|
||||
|
||||
__all__ = ["Charset", "charsets"]
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import DefaultDict, Dict, Optional, Sequence, Tuple
|
||||
|
||||
from ..errors import ProgrammingError
|
||||
|
||||
|
||||
@dataclass
|
||||
class Charset:
|
||||
"""Dataclass representing a character set."""
|
||||
|
||||
charset_id: int
|
||||
name: str
|
||||
collation: str
|
||||
is_default: bool
|
||||
|
||||
|
||||
class Charsets:
|
||||
"""MySQL supported character sets and collations class.
|
||||
|
||||
This class holds the list of character sets with their collations supported by
|
||||
MySQL, making available methods to get character sets by name, collation, or ID.
|
||||
It uses a sparse matrix or tree-like representation using a dict in a dict to hold
|
||||
the character set name and collations combinations.
|
||||
The list is hardcoded, so we avoid a database query when getting the name of the
|
||||
used character set or collation.
|
||||
|
||||
The call of ``charsets.set_mysql_major_version()`` should be done before using any
|
||||
of the retrieval methods.
|
||||
|
||||
Usage:
|
||||
>>> from mysql.connector.aio.charsets import charsets
|
||||
>>> charsets.set_mysql_major_version(8)
|
||||
>>> charsets.get_by_name("utf-8")
|
||||
Charset(charset_id=255,
|
||||
name='utf8mb4',
|
||||
collation='utf8mb4_0900_ai_ci',
|
||||
is_default=True)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._charset_id_store: Dict[int, Charset] = {}
|
||||
self._collation_store: Dict[str, Charset] = {}
|
||||
self._name_store: DefaultDict[str, Dict[str, Charset]] = defaultdict(dict)
|
||||
self._mysql_major_version: Optional[int] = None
|
||||
|
||||
def set_mysql_major_version(self, version: int) -> None:
|
||||
"""Set the MySQL major version.
|
||||
|
||||
Sets what tuple should be used based on the MySQL major version to store the
|
||||
list of character sets and collations.
|
||||
|
||||
Args:
|
||||
version: The MySQL major version (i.e. 8 or 5)
|
||||
"""
|
||||
self._mysql_major_version = version
|
||||
self._charset_id_store.clear()
|
||||
self._collation_store.clear()
|
||||
self._name_store.clear()
|
||||
|
||||
charsets_tuple: Sequence[Tuple[int, str, str, bool]] = None
|
||||
if version >= 8:
|
||||
charsets_tuple = MYSQL_8_CHARSETS
|
||||
elif version == 5:
|
||||
charsets_tuple = MYSQL_5_CHARSETS
|
||||
else:
|
||||
raise ProgrammingError("Invalid MySQL major version")
|
||||
|
||||
for charset_id, name, collation, is_default in charsets_tuple:
|
||||
charset = Charset(charset_id, name, collation, is_default)
|
||||
self._charset_id_store[charset_id] = charset
|
||||
self._collation_store[collation] = charset
|
||||
self._name_store[name][collation] = charset
|
||||
|
||||
def get_by_id(self, charset_id: int) -> Charset:
|
||||
"""Get character set by ID.
|
||||
|
||||
Args:
|
||||
charset_id: The charset ID.
|
||||
|
||||
Returns:
|
||||
Charset: The Charset dataclass instance.
|
||||
"""
|
||||
try:
|
||||
return self._charset_id_store[charset_id]
|
||||
except KeyError as err:
|
||||
raise ProgrammingError(f"Character set ID {charset_id} unknown") from err
|
||||
|
||||
def get_by_collation(self, collation: str) -> Charset:
|
||||
"""Get character set by collation.
|
||||
|
||||
Args:
|
||||
collation: The collation name.
|
||||
|
||||
Returns:
|
||||
Charset: The Charset dataclass instance.
|
||||
"""
|
||||
try:
|
||||
return self._collation_store[collation]
|
||||
except KeyError as err:
|
||||
raise ProgrammingError(f"Collation {collation} unknown") from err
|
||||
|
||||
def get_by_name(self, name: str) -> Charset:
|
||||
"""Get character set by name.
|
||||
|
||||
Args:
|
||||
name: The charset name.
|
||||
|
||||
Returns:
|
||||
Charset: The Charset dataclass instance.
|
||||
"""
|
||||
try:
|
||||
if name in ("utf8", "utf-8") and self._mysql_major_version == 8:
|
||||
name = "utf8mb4"
|
||||
for charset in self._name_store[name].values():
|
||||
if charset.is_default:
|
||||
return charset
|
||||
except KeyError as err:
|
||||
raise ProgrammingError(f"Character set name {name} unknown") from err
|
||||
raise ProgrammingError(f"No default was found for character set '{name}'")
|
||||
|
||||
def get_by_name_and_collation(self, name: str, collation: str) -> Charset:
|
||||
"""Get character set by name and collation.
|
||||
|
||||
Args:
|
||||
name: The charset name.
|
||||
collation: The collation name.
|
||||
|
||||
Returns:
|
||||
Charset: The Charset dataclass instance.
|
||||
"""
|
||||
try:
|
||||
return self._name_store[name][collation]
|
||||
except KeyError as err:
|
||||
raise ProgrammingError(
|
||||
f"Character set name '{name}' with collation '{collation}' not found"
|
||||
) from err
|
||||
|
||||
|
||||
MYSQL_8_CHARSETS = (
|
||||
(1, "big5", "big5_chinese_ci", True),
|
||||
(2, "latin2", "latin2_czech_cs", False),
|
||||
(3, "dec8", "dec8_swedish_ci", True),
|
||||
(4, "cp850", "cp850_general_ci", True),
|
||||
(5, "latin1", "latin1_german1_ci", False),
|
||||
(6, "hp8", "hp8_english_ci", True),
|
||||
(7, "koi8r", "koi8r_general_ci", True),
|
||||
(8, "latin1", "latin1_swedish_ci", True),
|
||||
(9, "latin2", "latin2_general_ci", True),
|
||||
(10, "swe7", "swe7_swedish_ci", True),
|
||||
(11, "ascii", "ascii_general_ci", True),
|
||||
(12, "ujis", "ujis_japanese_ci", True),
|
||||
(13, "sjis", "sjis_japanese_ci", True),
|
||||
(14, "cp1251", "cp1251_bulgarian_ci", False),
|
||||
(15, "latin1", "latin1_danish_ci", False),
|
||||
(16, "hebrew", "hebrew_general_ci", True),
|
||||
(18, "tis620", "tis620_thai_ci", True),
|
||||
(19, "euckr", "euckr_korean_ci", True),
|
||||
(20, "latin7", "latin7_estonian_cs", False),
|
||||
(21, "latin2", "latin2_hungarian_ci", False),
|
||||
(22, "koi8u", "koi8u_general_ci", True),
|
||||
(23, "cp1251", "cp1251_ukrainian_ci", False),
|
||||
(24, "gb2312", "gb2312_chinese_ci", True),
|
||||
(25, "greek", "greek_general_ci", True),
|
||||
(26, "cp1250", "cp1250_general_ci", True),
|
||||
(27, "latin2", "latin2_croatian_ci", False),
|
||||
(28, "gbk", "gbk_chinese_ci", True),
|
||||
(29, "cp1257", "cp1257_lithuanian_ci", False),
|
||||
(30, "latin5", "latin5_turkish_ci", True),
|
||||
(31, "latin1", "latin1_german2_ci", False),
|
||||
(32, "armscii8", "armscii8_general_ci", True),
|
||||
(33, "utf8mb3", "utf8mb3_general_ci", True),
|
||||
(34, "cp1250", "cp1250_czech_cs", False),
|
||||
(35, "ucs2", "ucs2_general_ci", True),
|
||||
(36, "cp866", "cp866_general_ci", True),
|
||||
(37, "keybcs2", "keybcs2_general_ci", True),
|
||||
(38, "macce", "macce_general_ci", True),
|
||||
(39, "macroman", "macroman_general_ci", True),
|
||||
(40, "cp852", "cp852_general_ci", True),
|
||||
(41, "latin7", "latin7_general_ci", True),
|
||||
(42, "latin7", "latin7_general_cs", False),
|
||||
(43, "macce", "macce_bin", False),
|
||||
(44, "cp1250", "cp1250_croatian_ci", False),
|
||||
(45, "utf8mb4", "utf8mb4_general_ci", False),
|
||||
(46, "utf8mb4", "utf8mb4_bin", False),
|
||||
(47, "latin1", "latin1_bin", False),
|
||||
(48, "latin1", "latin1_general_ci", False),
|
||||
(49, "latin1", "latin1_general_cs", False),
|
||||
(50, "cp1251", "cp1251_bin", False),
|
||||
(51, "cp1251", "cp1251_general_ci", True),
|
||||
(52, "cp1251", "cp1251_general_cs", False),
|
||||
(53, "macroman", "macroman_bin", False),
|
||||
(54, "utf16", "utf16_general_ci", True),
|
||||
(55, "utf16", "utf16_bin", False),
|
||||
(56, "utf16le", "utf16le_general_ci", True),
|
||||
(57, "cp1256", "cp1256_general_ci", True),
|
||||
(58, "cp1257", "cp1257_bin", False),
|
||||
(59, "cp1257", "cp1257_general_ci", True),
|
||||
(60, "utf32", "utf32_general_ci", True),
|
||||
(61, "utf32", "utf32_bin", False),
|
||||
(62, "utf16le", "utf16le_bin", False),
|
||||
(63, "binary", "binary", True),
|
||||
(64, "armscii8", "armscii8_bin", False),
|
||||
(65, "ascii", "ascii_bin", False),
|
||||
(66, "cp1250", "cp1250_bin", False),
|
||||
(67, "cp1256", "cp1256_bin", False),
|
||||
(68, "cp866", "cp866_bin", False),
|
||||
(69, "dec8", "dec8_bin", False),
|
||||
(70, "greek", "greek_bin", False),
|
||||
(71, "hebrew", "hebrew_bin", False),
|
||||
(72, "hp8", "hp8_bin", False),
|
||||
(73, "keybcs2", "keybcs2_bin", False),
|
||||
(74, "koi8r", "koi8r_bin", False),
|
||||
(75, "koi8u", "koi8u_bin", False),
|
||||
(76, "utf8mb3", "utf8mb3_tolower_ci", False),
|
||||
(77, "latin2", "latin2_bin", False),
|
||||
(78, "latin5", "latin5_bin", False),
|
||||
(79, "latin7", "latin7_bin", False),
|
||||
(80, "cp850", "cp850_bin", False),
|
||||
(81, "cp852", "cp852_bin", False),
|
||||
(82, "swe7", "swe7_bin", False),
|
||||
(83, "utf8mb3", "utf8mb3_bin", False),
|
||||
(84, "big5", "big5_bin", False),
|
||||
(85, "euckr", "euckr_bin", False),
|
||||
(86, "gb2312", "gb2312_bin", False),
|
||||
(87, "gbk", "gbk_bin", False),
|
||||
(88, "sjis", "sjis_bin", False),
|
||||
(89, "tis620", "tis620_bin", False),
|
||||
(90, "ucs2", "ucs2_bin", False),
|
||||
(91, "ujis", "ujis_bin", False),
|
||||
(92, "geostd8", "geostd8_general_ci", True),
|
||||
(93, "geostd8", "geostd8_bin", False),
|
||||
(94, "latin1", "latin1_spanish_ci", False),
|
||||
(95, "cp932", "cp932_japanese_ci", True),
|
||||
(96, "cp932", "cp932_bin", False),
|
||||
(97, "eucjpms", "eucjpms_japanese_ci", True),
|
||||
(98, "eucjpms", "eucjpms_bin", False),
|
||||
(99, "cp1250", "cp1250_polish_ci", False),
|
||||
(101, "utf16", "utf16_unicode_ci", False),
|
||||
(102, "utf16", "utf16_icelandic_ci", False),
|
||||
(103, "utf16", "utf16_latvian_ci", False),
|
||||
(104, "utf16", "utf16_romanian_ci", False),
|
||||
(105, "utf16", "utf16_slovenian_ci", False),
|
||||
(106, "utf16", "utf16_polish_ci", False),
|
||||
(107, "utf16", "utf16_estonian_ci", False),
|
||||
(108, "utf16", "utf16_spanish_ci", False),
|
||||
(109, "utf16", "utf16_swedish_ci", False),
|
||||
(110, "utf16", "utf16_turkish_ci", False),
|
||||
(111, "utf16", "utf16_czech_ci", False),
|
||||
(112, "utf16", "utf16_danish_ci", False),
|
||||
(113, "utf16", "utf16_lithuanian_ci", False),
|
||||
(114, "utf16", "utf16_slovak_ci", False),
|
||||
(115, "utf16", "utf16_spanish2_ci", False),
|
||||
(116, "utf16", "utf16_roman_ci", False),
|
||||
(117, "utf16", "utf16_persian_ci", False),
|
||||
(118, "utf16", "utf16_esperanto_ci", False),
|
||||
(119, "utf16", "utf16_hungarian_ci", False),
|
||||
(120, "utf16", "utf16_sinhala_ci", False),
|
||||
(121, "utf16", "utf16_german2_ci", False),
|
||||
(122, "utf16", "utf16_croatian_ci", False),
|
||||
(123, "utf16", "utf16_unicode_520_ci", False),
|
||||
(124, "utf16", "utf16_vietnamese_ci", False),
|
||||
(128, "ucs2", "ucs2_unicode_ci", False),
|
||||
(129, "ucs2", "ucs2_icelandic_ci", False),
|
||||
(130, "ucs2", "ucs2_latvian_ci", False),
|
||||
(131, "ucs2", "ucs2_romanian_ci", False),
|
||||
(132, "ucs2", "ucs2_slovenian_ci", False),
|
||||
(133, "ucs2", "ucs2_polish_ci", False),
|
||||
(134, "ucs2", "ucs2_estonian_ci", False),
|
||||
(135, "ucs2", "ucs2_spanish_ci", False),
|
||||
(136, "ucs2", "ucs2_swedish_ci", False),
|
||||
(137, "ucs2", "ucs2_turkish_ci", False),
|
||||
(138, "ucs2", "ucs2_czech_ci", False),
|
||||
(139, "ucs2", "ucs2_danish_ci", False),
|
||||
(140, "ucs2", "ucs2_lithuanian_ci", False),
|
||||
(141, "ucs2", "ucs2_slovak_ci", False),
|
||||
(142, "ucs2", "ucs2_spanish2_ci", False),
|
||||
(143, "ucs2", "ucs2_roman_ci", False),
|
||||
(144, "ucs2", "ucs2_persian_ci", False),
|
||||
(145, "ucs2", "ucs2_esperanto_ci", False),
|
||||
(146, "ucs2", "ucs2_hungarian_ci", False),
|
||||
(147, "ucs2", "ucs2_sinhala_ci", False),
|
||||
(148, "ucs2", "ucs2_german2_ci", False),
|
||||
(149, "ucs2", "ucs2_croatian_ci", False),
|
||||
(150, "ucs2", "ucs2_unicode_520_ci", False),
|
||||
(151, "ucs2", "ucs2_vietnamese_ci", False),
|
||||
(159, "ucs2", "ucs2_general_mysql500_ci", False),
|
||||
(160, "utf32", "utf32_unicode_ci", False),
|
||||
(161, "utf32", "utf32_icelandic_ci", False),
|
||||
(162, "utf32", "utf32_latvian_ci", False),
|
||||
(163, "utf32", "utf32_romanian_ci", False),
|
||||
(164, "utf32", "utf32_slovenian_ci", False),
|
||||
(165, "utf32", "utf32_polish_ci", False),
|
||||
(166, "utf32", "utf32_estonian_ci", False),
|
||||
(167, "utf32", "utf32_spanish_ci", False),
|
||||
(168, "utf32", "utf32_swedish_ci", False),
|
||||
(169, "utf32", "utf32_turkish_ci", False),
|
||||
(170, "utf32", "utf32_czech_ci", False),
|
||||
(171, "utf32", "utf32_danish_ci", False),
|
||||
(172, "utf32", "utf32_lithuanian_ci", False),
|
||||
(173, "utf32", "utf32_slovak_ci", False),
|
||||
(174, "utf32", "utf32_spanish2_ci", False),
|
||||
(175, "utf32", "utf32_roman_ci", False),
|
||||
(176, "utf32", "utf32_persian_ci", False),
|
||||
(177, "utf32", "utf32_esperanto_ci", False),
|
||||
(178, "utf32", "utf32_hungarian_ci", False),
|
||||
(179, "utf32", "utf32_sinhala_ci", False),
|
||||
(180, "utf32", "utf32_german2_ci", False),
|
||||
(181, "utf32", "utf32_croatian_ci", False),
|
||||
(182, "utf32", "utf32_unicode_520_ci", False),
|
||||
(183, "utf32", "utf32_vietnamese_ci", False),
|
||||
(192, "utf8mb3", "utf8mb3_unicode_ci", False),
|
||||
(193, "utf8mb3", "utf8mb3_icelandic_ci", False),
|
||||
(194, "utf8mb3", "utf8mb3_latvian_ci", False),
|
||||
(195, "utf8mb3", "utf8mb3_romanian_ci", False),
|
||||
(196, "utf8mb3", "utf8mb3_slovenian_ci", False),
|
||||
(197, "utf8mb3", "utf8mb3_polish_ci", False),
|
||||
(198, "utf8mb3", "utf8mb3_estonian_ci", False),
|
||||
(199, "utf8mb3", "utf8mb3_spanish_ci", False),
|
||||
(200, "utf8mb3", "utf8mb3_swedish_ci", False),
|
||||
(201, "utf8mb3", "utf8mb3_turkish_ci", False),
|
||||
(202, "utf8mb3", "utf8mb3_czech_ci", False),
|
||||
(203, "utf8mb3", "utf8mb3_danish_ci", False),
|
||||
(204, "utf8mb3", "utf8mb3_lithuanian_ci", False),
|
||||
(205, "utf8mb3", "utf8mb3_slovak_ci", False),
|
||||
(206, "utf8mb3", "utf8mb3_spanish2_ci", False),
|
||||
(207, "utf8mb3", "utf8mb3_roman_ci", False),
|
||||
(208, "utf8mb3", "utf8mb3_persian_ci", False),
|
||||
(209, "utf8mb3", "utf8mb3_esperanto_ci", False),
|
||||
(210, "utf8mb3", "utf8mb3_hungarian_ci", False),
|
||||
(211, "utf8mb3", "utf8mb3_sinhala_ci", False),
|
||||
(212, "utf8mb3", "utf8mb3_german2_ci", False),
|
||||
(213, "utf8mb3", "utf8mb3_croatian_ci", False),
|
||||
(214, "utf8mb3", "utf8mb3_unicode_520_ci", False),
|
||||
(215, "utf8mb3", "utf8mb3_vietnamese_ci", False),
|
||||
(223, "utf8mb3", "utf8mb3_general_mysql500_ci", False),
|
||||
(224, "utf8mb4", "utf8mb4_unicode_ci", False),
|
||||
(225, "utf8mb4", "utf8mb4_icelandic_ci", False),
|
||||
(226, "utf8mb4", "utf8mb4_latvian_ci", False),
|
||||
(227, "utf8mb4", "utf8mb4_romanian_ci", False),
|
||||
(228, "utf8mb4", "utf8mb4_slovenian_ci", False),
|
||||
(229, "utf8mb4", "utf8mb4_polish_ci", False),
|
||||
(230, "utf8mb4", "utf8mb4_estonian_ci", False),
|
||||
(231, "utf8mb4", "utf8mb4_spanish_ci", False),
|
||||
(232, "utf8mb4", "utf8mb4_swedish_ci", False),
|
||||
(233, "utf8mb4", "utf8mb4_turkish_ci", False),
|
||||
(234, "utf8mb4", "utf8mb4_czech_ci", False),
|
||||
(235, "utf8mb4", "utf8mb4_danish_ci", False),
|
||||
(236, "utf8mb4", "utf8mb4_lithuanian_ci", False),
|
||||
(237, "utf8mb4", "utf8mb4_slovak_ci", False),
|
||||
(238, "utf8mb4", "utf8mb4_spanish2_ci", False),
|
||||
(239, "utf8mb4", "utf8mb4_roman_ci", False),
|
||||
(240, "utf8mb4", "utf8mb4_persian_ci", False),
|
||||
(241, "utf8mb4", "utf8mb4_esperanto_ci", False),
|
||||
(242, "utf8mb4", "utf8mb4_hungarian_ci", False),
|
||||
(243, "utf8mb4", "utf8mb4_sinhala_ci", False),
|
||||
(244, "utf8mb4", "utf8mb4_german2_ci", False),
|
||||
(245, "utf8mb4", "utf8mb4_croatian_ci", False),
|
||||
(246, "utf8mb4", "utf8mb4_unicode_520_ci", False),
|
||||
(247, "utf8mb4", "utf8mb4_vietnamese_ci", False),
|
||||
(248, "gb18030", "gb18030_chinese_ci", True),
|
||||
(249, "gb18030", "gb18030_bin", False),
|
||||
(250, "gb18030", "gb18030_unicode_520_ci", False),
|
||||
(255, "utf8mb4", "utf8mb4_0900_ai_ci", True),
|
||||
(256, "utf8mb4", "utf8mb4_de_pb_0900_ai_ci", False),
|
||||
(257, "utf8mb4", "utf8mb4_is_0900_ai_ci", False),
|
||||
(258, "utf8mb4", "utf8mb4_lv_0900_ai_ci", False),
|
||||
(259, "utf8mb4", "utf8mb4_ro_0900_ai_ci", False),
|
||||
(260, "utf8mb4", "utf8mb4_sl_0900_ai_ci", False),
|
||||
(261, "utf8mb4", "utf8mb4_pl_0900_ai_ci", False),
|
||||
(262, "utf8mb4", "utf8mb4_et_0900_ai_ci", False),
|
||||
(263, "utf8mb4", "utf8mb4_es_0900_ai_ci", False),
|
||||
(264, "utf8mb4", "utf8mb4_sv_0900_ai_ci", False),
|
||||
(265, "utf8mb4", "utf8mb4_tr_0900_ai_ci", False),
|
||||
(266, "utf8mb4", "utf8mb4_cs_0900_ai_ci", False),
|
||||
(267, "utf8mb4", "utf8mb4_da_0900_ai_ci", False),
|
||||
(268, "utf8mb4", "utf8mb4_lt_0900_ai_ci", False),
|
||||
(269, "utf8mb4", "utf8mb4_sk_0900_ai_ci", False),
|
||||
(270, "utf8mb4", "utf8mb4_es_trad_0900_ai_ci", False),
|
||||
(271, "utf8mb4", "utf8mb4_la_0900_ai_ci", False),
|
||||
(273, "utf8mb4", "utf8mb4_eo_0900_ai_ci", False),
|
||||
(274, "utf8mb4", "utf8mb4_hu_0900_ai_ci", False),
|
||||
(275, "utf8mb4", "utf8mb4_hr_0900_ai_ci", False),
|
||||
(277, "utf8mb4", "utf8mb4_vi_0900_ai_ci", False),
|
||||
(278, "utf8mb4", "utf8mb4_0900_as_cs", False),
|
||||
(279, "utf8mb4", "utf8mb4_de_pb_0900_as_cs", False),
|
||||
(280, "utf8mb4", "utf8mb4_is_0900_as_cs", False),
|
||||
(281, "utf8mb4", "utf8mb4_lv_0900_as_cs", False),
|
||||
(282, "utf8mb4", "utf8mb4_ro_0900_as_cs", False),
|
||||
(283, "utf8mb4", "utf8mb4_sl_0900_as_cs", False),
|
||||
(284, "utf8mb4", "utf8mb4_pl_0900_as_cs", False),
|
||||
(285, "utf8mb4", "utf8mb4_et_0900_as_cs", False),
|
||||
(286, "utf8mb4", "utf8mb4_es_0900_as_cs", False),
|
||||
(287, "utf8mb4", "utf8mb4_sv_0900_as_cs", False),
|
||||
(288, "utf8mb4", "utf8mb4_tr_0900_as_cs", False),
|
||||
(289, "utf8mb4", "utf8mb4_cs_0900_as_cs", False),
|
||||
(290, "utf8mb4", "utf8mb4_da_0900_as_cs", False),
|
||||
(291, "utf8mb4", "utf8mb4_lt_0900_as_cs", False),
|
||||
(292, "utf8mb4", "utf8mb4_sk_0900_as_cs", False),
|
||||
(293, "utf8mb4", "utf8mb4_es_trad_0900_as_cs", False),
|
||||
(294, "utf8mb4", "utf8mb4_la_0900_as_cs", False),
|
||||
(296, "utf8mb4", "utf8mb4_eo_0900_as_cs", False),
|
||||
(297, "utf8mb4", "utf8mb4_hu_0900_as_cs", False),
|
||||
(298, "utf8mb4", "utf8mb4_hr_0900_as_cs", False),
|
||||
(300, "utf8mb4", "utf8mb4_vi_0900_as_cs", False),
|
||||
(303, "utf8mb4", "utf8mb4_ja_0900_as_cs", False),
|
||||
(304, "utf8mb4", "utf8mb4_ja_0900_as_cs_ks", False),
|
||||
(305, "utf8mb4", "utf8mb4_0900_as_ci", False),
|
||||
(306, "utf8mb4", "utf8mb4_ru_0900_ai_ci", False),
|
||||
(307, "utf8mb4", "utf8mb4_ru_0900_as_cs", False),
|
||||
(308, "utf8mb4", "utf8mb4_zh_0900_as_cs", False),
|
||||
(309, "utf8mb4", "utf8mb4_0900_bin", False),
|
||||
(310, "utf8mb4", "utf8mb4_nb_0900_ai_ci", False),
|
||||
(311, "utf8mb4", "utf8mb4_nb_0900_as_cs", False),
|
||||
(312, "utf8mb4", "utf8mb4_nn_0900_ai_ci", False),
|
||||
(313, "utf8mb4", "utf8mb4_nn_0900_as_cs", False),
|
||||
(314, "utf8mb4", "utf8mb4_sr_latn_0900_ai_ci", False),
|
||||
(315, "utf8mb4", "utf8mb4_sr_latn_0900_as_cs", False),
|
||||
(316, "utf8mb4", "utf8mb4_bs_0900_ai_ci", False),
|
||||
(317, "utf8mb4", "utf8mb4_bs_0900_as_cs", False),
|
||||
(318, "utf8mb4", "utf8mb4_bg_0900_ai_ci", False),
|
||||
(319, "utf8mb4", "utf8mb4_bg_0900_as_cs", False),
|
||||
(320, "utf8mb4", "utf8mb4_gl_0900_ai_ci", False),
|
||||
(321, "utf8mb4", "utf8mb4_gl_0900_as_cs", False),
|
||||
(322, "utf8mb4", "utf8mb4_mn_cyrl_0900_ai_ci", False),
|
||||
(323, "utf8mb4", "utf8mb4_mn_cyrl_0900_as_cs", False),
|
||||
)
|
||||
|
||||
MYSQL_5_CHARSETS = (
|
||||
(1, "big5", "big5_chinese_ci", True),
|
||||
(2, "latin2", "latin2_czech_cs", False),
|
||||
(3, "dec8", "dec8_swedish_ci", True),
|
||||
(4, "cp850", "cp850_general_ci", True),
|
||||
(5, "latin1", "latin1_german1_ci", False),
|
||||
(6, "hp8", "hp8_english_ci", True),
|
||||
(7, "koi8r", "koi8r_general_ci", True),
|
||||
(8, "latin1", "latin1_swedish_ci", True),
|
||||
(9, "latin2", "latin2_general_ci", True),
|
||||
(10, "swe7", "swe7_swedish_ci", True),
|
||||
(11, "ascii", "ascii_general_ci", True),
|
||||
(12, "ujis", "ujis_japanese_ci", True),
|
||||
(13, "sjis", "sjis_japanese_ci", True),
|
||||
(14, "cp1251", "cp1251_bulgarian_ci", False),
|
||||
(15, "latin1", "latin1_danish_ci", False),
|
||||
(16, "hebrew", "hebrew_general_ci", True),
|
||||
(18, "tis620", "tis620_thai_ci", True),
|
||||
(19, "euckr", "euckr_korean_ci", True),
|
||||
(20, "latin7", "latin7_estonian_cs", False),
|
||||
(21, "latin2", "latin2_hungarian_ci", False),
|
||||
(22, "koi8u", "koi8u_general_ci", True),
|
||||
(23, "cp1251", "cp1251_ukrainian_ci", False),
|
||||
(24, "gb2312", "gb2312_chinese_ci", True),
|
||||
(25, "greek", "greek_general_ci", True),
|
||||
(26, "cp1250", "cp1250_general_ci", True),
|
||||
(27, "latin2", "latin2_croatian_ci", False),
|
||||
(28, "gbk", "gbk_chinese_ci", True),
|
||||
(29, "cp1257", "cp1257_lithuanian_ci", False),
|
||||
(30, "latin5", "latin5_turkish_ci", True),
|
||||
(31, "latin1", "latin1_german2_ci", False),
|
||||
(32, "armscii8", "armscii8_general_ci", True),
|
||||
(33, "utf8", "utf8_general_ci", True),
|
||||
(34, "cp1250", "cp1250_czech_cs", False),
|
||||
(35, "ucs2", "ucs2_general_ci", True),
|
||||
(36, "cp866", "cp866_general_ci", True),
|
||||
(37, "keybcs2", "keybcs2_general_ci", True),
|
||||
(38, "macce", "macce_general_ci", True),
|
||||
(39, "macroman", "macroman_general_ci", True),
|
||||
(40, "cp852", "cp852_general_ci", True),
|
||||
(41, "latin7", "latin7_general_ci", True),
|
||||
(42, "latin7", "latin7_general_cs", False),
|
||||
(43, "macce", "macce_bin", False),
|
||||
(44, "cp1250", "cp1250_croatian_ci", False),
|
||||
(45, "utf8mb4", "utf8mb4_general_ci", True),
|
||||
(46, "utf8mb4", "utf8mb4_bin", False),
|
||||
(47, "latin1", "latin1_bin", False),
|
||||
(48, "latin1", "latin1_general_ci", False),
|
||||
(49, "latin1", "latin1_general_cs", False),
|
||||
(50, "cp1251", "cp1251_bin", False),
|
||||
(51, "cp1251", "cp1251_general_ci", True),
|
||||
(52, "cp1251", "cp1251_general_cs", False),
|
||||
(53, "macroman", "macroman_bin", False),
|
||||
(54, "utf16", "utf16_general_ci", True),
|
||||
(55, "utf16", "utf16_bin", False),
|
||||
(56, "utf16le", "utf16le_general_ci", True),
|
||||
(57, "cp1256", "cp1256_general_ci", True),
|
||||
(58, "cp1257", "cp1257_bin", False),
|
||||
(59, "cp1257", "cp1257_general_ci", True),
|
||||
(60, "utf32", "utf32_general_ci", True),
|
||||
(61, "utf32", "utf32_bin", False),
|
||||
(62, "utf16le", "utf16le_bin", False),
|
||||
(63, "binary", "binary", True),
|
||||
(64, "armscii8", "armscii8_bin", False),
|
||||
(65, "ascii", "ascii_bin", False),
|
||||
(66, "cp1250", "cp1250_bin", False),
|
||||
(67, "cp1256", "cp1256_bin", False),
|
||||
(68, "cp866", "cp866_bin", False),
|
||||
(69, "dec8", "dec8_bin", False),
|
||||
(70, "greek", "greek_bin", False),
|
||||
(71, "hebrew", "hebrew_bin", False),
|
||||
(72, "hp8", "hp8_bin", False),
|
||||
(73, "keybcs2", "keybcs2_bin", False),
|
||||
(74, "koi8r", "koi8r_bin", False),
|
||||
(75, "koi8u", "koi8u_bin", False),
|
||||
(77, "latin2", "latin2_bin", False),
|
||||
(78, "latin5", "latin5_bin", False),
|
||||
(79, "latin7", "latin7_bin", False),
|
||||
(80, "cp850", "cp850_bin", False),
|
||||
(81, "cp852", "cp852_bin", False),
|
||||
(82, "swe7", "swe7_bin", False),
|
||||
(83, "utf8", "utf8_bin", False),
|
||||
(84, "big5", "big5_bin", False),
|
||||
(85, "euckr", "euckr_bin", False),
|
||||
(86, "gb2312", "gb2312_bin", False),
|
||||
(87, "gbk", "gbk_bin", False),
|
||||
(88, "sjis", "sjis_bin", False),
|
||||
(89, "tis620", "tis620_bin", False),
|
||||
(90, "ucs2", "ucs2_bin", False),
|
||||
(91, "ujis", "ujis_bin", False),
|
||||
(92, "geostd8", "geostd8_general_ci", True),
|
||||
(93, "geostd8", "geostd8_bin", False),
|
||||
(94, "latin1", "latin1_spanish_ci", False),
|
||||
(95, "cp932", "cp932_japanese_ci", True),
|
||||
(96, "cp932", "cp932_bin", False),
|
||||
(97, "eucjpms", "eucjpms_japanese_ci", True),
|
||||
(98, "eucjpms", "eucjpms_bin", False),
|
||||
(99, "cp1250", "cp1250_polish_ci", False),
|
||||
(101, "utf16", "utf16_unicode_ci", False),
|
||||
(102, "utf16", "utf16_icelandic_ci", False),
|
||||
(103, "utf16", "utf16_latvian_ci", False),
|
||||
(104, "utf16", "utf16_romanian_ci", False),
|
||||
(105, "utf16", "utf16_slovenian_ci", False),
|
||||
(106, "utf16", "utf16_polish_ci", False),
|
||||
(107, "utf16", "utf16_estonian_ci", False),
|
||||
(108, "utf16", "utf16_spanish_ci", False),
|
||||
(109, "utf16", "utf16_swedish_ci", False),
|
||||
(110, "utf16", "utf16_turkish_ci", False),
|
||||
(111, "utf16", "utf16_czech_ci", False),
|
||||
(112, "utf16", "utf16_danish_ci", False),
|
||||
(113, "utf16", "utf16_lithuanian_ci", False),
|
||||
(114, "utf16", "utf16_slovak_ci", False),
|
||||
(115, "utf16", "utf16_spanish2_ci", False),
|
||||
(116, "utf16", "utf16_roman_ci", False),
|
||||
(117, "utf16", "utf16_persian_ci", False),
|
||||
(118, "utf16", "utf16_esperanto_ci", False),
|
||||
(119, "utf16", "utf16_hungarian_ci", False),
|
||||
(120, "utf16", "utf16_sinhala_ci", False),
|
||||
(121, "utf16", "utf16_german2_ci", False),
|
||||
(122, "utf16", "utf16_croatian_ci", False),
|
||||
(123, "utf16", "utf16_unicode_520_ci", False),
|
||||
(124, "utf16", "utf16_vietnamese_ci", False),
|
||||
(128, "ucs2", "ucs2_unicode_ci", False),
|
||||
(129, "ucs2", "ucs2_icelandic_ci", False),
|
||||
(130, "ucs2", "ucs2_latvian_ci", False),
|
||||
(131, "ucs2", "ucs2_romanian_ci", False),
|
||||
(132, "ucs2", "ucs2_slovenian_ci", False),
|
||||
(133, "ucs2", "ucs2_polish_ci", False),
|
||||
(134, "ucs2", "ucs2_estonian_ci", False),
|
||||
(135, "ucs2", "ucs2_spanish_ci", False),
|
||||
(136, "ucs2", "ucs2_swedish_ci", False),
|
||||
(137, "ucs2", "ucs2_turkish_ci", False),
|
||||
(138, "ucs2", "ucs2_czech_ci", False),
|
||||
(139, "ucs2", "ucs2_danish_ci", False),
|
||||
(140, "ucs2", "ucs2_lithuanian_ci", False),
|
||||
(141, "ucs2", "ucs2_slovak_ci", False),
|
||||
(142, "ucs2", "ucs2_spanish2_ci", False),
|
||||
(143, "ucs2", "ucs2_roman_ci", False),
|
||||
(144, "ucs2", "ucs2_persian_ci", False),
|
||||
(145, "ucs2", "ucs2_esperanto_ci", False),
|
||||
(146, "ucs2", "ucs2_hungarian_ci", False),
|
||||
(147, "ucs2", "ucs2_sinhala_ci", False),
|
||||
(148, "ucs2", "ucs2_german2_ci", False),
|
||||
(149, "ucs2", "ucs2_croatian_ci", False),
|
||||
(150, "ucs2", "ucs2_unicode_520_ci", False),
|
||||
(151, "ucs2", "ucs2_vietnamese_ci", False),
|
||||
(159, "ucs2", "ucs2_general_mysql500_ci", False),
|
||||
(160, "utf32", "utf32_unicode_ci", False),
|
||||
(161, "utf32", "utf32_icelandic_ci", False),
|
||||
(162, "utf32", "utf32_latvian_ci", False),
|
||||
(163, "utf32", "utf32_romanian_ci", False),
|
||||
(164, "utf32", "utf32_slovenian_ci", False),
|
||||
(165, "utf32", "utf32_polish_ci", False),
|
||||
(166, "utf32", "utf32_estonian_ci", False),
|
||||
(167, "utf32", "utf32_spanish_ci", False),
|
||||
(168, "utf32", "utf32_swedish_ci", False),
|
||||
(169, "utf32", "utf32_turkish_ci", False),
|
||||
(170, "utf32", "utf32_czech_ci", False),
|
||||
(171, "utf32", "utf32_danish_ci", False),
|
||||
(172, "utf32", "utf32_lithuanian_ci", False),
|
||||
(173, "utf32", "utf32_slovak_ci", False),
|
||||
(174, "utf32", "utf32_spanish2_ci", False),
|
||||
(175, "utf32", "utf32_roman_ci", False),
|
||||
(176, "utf32", "utf32_persian_ci", False),
|
||||
(177, "utf32", "utf32_esperanto_ci", False),
|
||||
(178, "utf32", "utf32_hungarian_ci", False),
|
||||
(179, "utf32", "utf32_sinhala_ci", False),
|
||||
(180, "utf32", "utf32_german2_ci", False),
|
||||
(181, "utf32", "utf32_croatian_ci", False),
|
||||
(182, "utf32", "utf32_unicode_520_ci", False),
|
||||
(183, "utf32", "utf32_vietnamese_ci", False),
|
||||
(192, "utf8", "utf8_unicode_ci", False),
|
||||
(193, "utf8", "utf8_icelandic_ci", False),
|
||||
(194, "utf8", "utf8_latvian_ci", False),
|
||||
(195, "utf8", "utf8_romanian_ci", False),
|
||||
(196, "utf8", "utf8_slovenian_ci", False),
|
||||
(197, "utf8", "utf8_polish_ci", False),
|
||||
(198, "utf8", "utf8_estonian_ci", False),
|
||||
(199, "utf8", "utf8_spanish_ci", False),
|
||||
(200, "utf8", "utf8_swedish_ci", False),
|
||||
(201, "utf8", "utf8_turkish_ci", False),
|
||||
(202, "utf8", "utf8_czech_ci", False),
|
||||
(203, "utf8", "utf8_danish_ci", False),
|
||||
(204, "utf8", "utf8_lithuanian_ci", False),
|
||||
(205, "utf8", "utf8_slovak_ci", False),
|
||||
(206, "utf8", "utf8_spanish2_ci", False),
|
||||
(207, "utf8", "utf8_roman_ci", False),
|
||||
(208, "utf8", "utf8_persian_ci", False),
|
||||
(209, "utf8", "utf8_esperanto_ci", False),
|
||||
(210, "utf8", "utf8_hungarian_ci", False),
|
||||
(211, "utf8", "utf8_sinhala_ci", False),
|
||||
(212, "utf8", "utf8_german2_ci", False),
|
||||
(213, "utf8", "utf8_croatian_ci", False),
|
||||
(214, "utf8", "utf8_unicode_520_ci", False),
|
||||
(215, "utf8", "utf8_vietnamese_ci", False),
|
||||
(223, "utf8", "utf8_general_mysql500_ci", False),
|
||||
(224, "utf8mb4", "utf8mb4_unicode_ci", False),
|
||||
(225, "utf8mb4", "utf8mb4_icelandic_ci", False),
|
||||
(226, "utf8mb4", "utf8mb4_latvian_ci", False),
|
||||
(227, "utf8mb4", "utf8mb4_romanian_ci", False),
|
||||
(228, "utf8mb4", "utf8mb4_slovenian_ci", False),
|
||||
(229, "utf8mb4", "utf8mb4_polish_ci", False),
|
||||
(230, "utf8mb4", "utf8mb4_estonian_ci", False),
|
||||
(231, "utf8mb4", "utf8mb4_spanish_ci", False),
|
||||
(232, "utf8mb4", "utf8mb4_swedish_ci", False),
|
||||
(233, "utf8mb4", "utf8mb4_turkish_ci", False),
|
||||
(234, "utf8mb4", "utf8mb4_czech_ci", False),
|
||||
(235, "utf8mb4", "utf8mb4_danish_ci", False),
|
||||
(236, "utf8mb4", "utf8mb4_lithuanian_ci", False),
|
||||
(237, "utf8mb4", "utf8mb4_slovak_ci", False),
|
||||
(238, "utf8mb4", "utf8mb4_spanish2_ci", False),
|
||||
(239, "utf8mb4", "utf8mb4_roman_ci", False),
|
||||
(240, "utf8mb4", "utf8mb4_persian_ci", False),
|
||||
(241, "utf8mb4", "utf8mb4_esperanto_ci", False),
|
||||
(242, "utf8mb4", "utf8mb4_hungarian_ci", False),
|
||||
(243, "utf8mb4", "utf8mb4_sinhala_ci", False),
|
||||
(244, "utf8mb4", "utf8mb4_german2_ci", False),
|
||||
(245, "utf8mb4", "utf8mb4_croatian_ci", False),
|
||||
(246, "utf8mb4", "utf8mb4_unicode_520_ci", False),
|
||||
(247, "utf8mb4", "utf8mb4_vietnamese_ci", False),
|
||||
(248, "gb18030", "gb18030_chinese_ci", True),
|
||||
(249, "gb18030", "gb18030_bin", False),
|
||||
(250, "gb18030", "gb18030_unicode_520_ci", False),
|
||||
)
|
||||
|
||||
charsets = Charsets()
|
||||
1510
myenv/lib/python3.11/site-packages/mysql/connector/aio/connection.py
Normal file
1510
myenv/lib/python3.11/site-packages/mysql/connector/aio/connection.py
Normal file
File diff suppressed because it is too large
Load Diff
1585
myenv/lib/python3.11/site-packages/mysql/connector/aio/cursor.py
Normal file
1585
myenv/lib/python3.11/site-packages/mysql/connector/aio/cursor.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Setup of the `mysql.connector.aio` logger."""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("mysql.connector.aio")
|
||||
@@ -0,0 +1,761 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
"""Module implementing low-level socket communication with MySQL servers."""
|
||||
|
||||
|
||||
__all__ = ["MySQLTcpSocket", "MySQLUnixSocket"]
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
import zlib
|
||||
|
||||
try:
|
||||
import ssl
|
||||
|
||||
TLS_VERSIONS = {
|
||||
"TLSv1": ssl.PROTOCOL_TLSv1,
|
||||
"TLSv1.1": ssl.PROTOCOL_TLSv1_1,
|
||||
"TLSv1.2": ssl.PROTOCOL_TLSv1_2,
|
||||
"TLSv1.3": ssl.PROTOCOL_TLS,
|
||||
}
|
||||
except ImportError:
|
||||
ssl = None
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from typing import Any, Deque, List, Optional, Tuple
|
||||
|
||||
from ..errors import (
|
||||
InterfaceError,
|
||||
NotSupportedError,
|
||||
OperationalError,
|
||||
ProgrammingError,
|
||||
ReadTimeoutError,
|
||||
WriteTimeoutError,
|
||||
)
|
||||
from ..network import (
|
||||
COMPRESSED_PACKET_HEADER_LENGTH,
|
||||
MAX_PAYLOAD_LENGTH,
|
||||
MIN_COMPRESS_LENGTH,
|
||||
PACKET_HEADER_LENGTH,
|
||||
)
|
||||
from .utils import StreamWriter, open_connection
|
||||
|
||||
|
||||
def _strioerror(err: IOError) -> str:
|
||||
"""Reformat the IOError error message.
|
||||
|
||||
This function reformats the IOError error message.
|
||||
"""
|
||||
return str(err) if not err.errno else f"{err.errno} {err.strerror}"
|
||||
|
||||
|
||||
class NetworkBroker(ABC):
|
||||
"""Broker class interface.
|
||||
|
||||
The network object is a broker used as a delegate by a socket object. Whenever the
|
||||
socket wants to deliver or get packets to or from the MySQL server it needs to rely
|
||||
on its network broker (netbroker).
|
||||
|
||||
The netbroker sends `payloads` and receives `packets`.
|
||||
|
||||
A packet is a bytes sequence, it has a header and body (referred to as payload).
|
||||
The first `PACKET_HEADER_LENGTH` or `COMPRESSED_PACKET_HEADER_LENGTH`
|
||||
(as appropriate) bytes correspond to the `header`, the remaining ones represent the
|
||||
`payload`.
|
||||
|
||||
The maximum payload length allowed to be sent per packet to the server is
|
||||
`MAX_PAYLOAD_LENGTH`. When `send` is called with a payload whose length is greater
|
||||
than `MAX_PAYLOAD_LENGTH` the netbroker breaks it down into packets, so the caller
|
||||
of `send` can provide payloads of arbitrary length.
|
||||
|
||||
Finally, data received by the netbroker comes directly from the server, expect to
|
||||
get a packet for each call to `recv`. The received packet contains a header and
|
||||
payload, the latter respecting `MAX_PAYLOAD_LENGTH`.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def write(
|
||||
self,
|
||||
writer: StreamWriter,
|
||||
address: str,
|
||||
payload: bytes,
|
||||
packet_number: Optional[int] = None,
|
||||
compressed_packet_number: Optional[int] = None,
|
||||
write_timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Send `payload` to the MySQL server.
|
||||
|
||||
If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
|
||||
broken down into packets.
|
||||
|
||||
Args:
|
||||
sock: Object holding the socket connection.
|
||||
address: Socket's location.
|
||||
payload: Packet's body to send.
|
||||
packet_number: Sequence id (packet ID) to attach to the header when sending
|
||||
plain packets.
|
||||
compressed_packet_number: Same as `packet_number` but used when sending
|
||||
compressed packets.
|
||||
write_timeout: Timeout in seconds before which sending a packet to the server
|
||||
should finish else WriteTimeoutError is raised.
|
||||
|
||||
|
||||
Raises:
|
||||
:class:`OperationalError`: If something goes wrong while sending packets to
|
||||
the MySQL server.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def read(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
address: str,
|
||||
read_timeout: Optional[int] = None,
|
||||
) -> bytearray:
|
||||
"""Get the next available packet from the MySQL server.
|
||||
|
||||
Args:
|
||||
sock: Object holding the socket connection.
|
||||
address: Socket's location.
|
||||
read_timeout: Timeout in seconds before which reading a packet from the server
|
||||
should finish.
|
||||
|
||||
Returns:
|
||||
packet: A packet from the MySQL server.
|
||||
|
||||
Raises:
|
||||
:class:`OperationalError`: If something goes wrong while receiving packets
|
||||
from the MySQL server.
|
||||
:class:`ReadTimeoutError`: If the time to receive a packet from the server takes
|
||||
longer than `read_timeout`.
|
||||
:class:`InterfaceError`: If something goes wrong while receiving packets
|
||||
from the MySQL server.
|
||||
"""
|
||||
|
||||
|
||||
class NetworkBrokerPlain(NetworkBroker):
|
||||
"""Broker class for MySQL socket communication."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pktnr: int = -1 # packet number
|
||||
|
||||
@staticmethod
|
||||
def get_header(pkt: bytes) -> Tuple[int, int]:
|
||||
"""Recover the header information from a packet."""
|
||||
if len(pkt) < PACKET_HEADER_LENGTH:
|
||||
raise ValueError("Can't recover header info from an incomplete packet")
|
||||
|
||||
pll, seqid = (
|
||||
struct.unpack("<I", pkt[0:3] + b"\x00")[0],
|
||||
pkt[3],
|
||||
)
|
||||
# payload length, sequence id
|
||||
return pll, seqid
|
||||
|
||||
def _set_next_pktnr(self, next_id: Optional[int] = None) -> None:
|
||||
"""Set the given packet id, if any, else increment packet id."""
|
||||
if next_id is None:
|
||||
self._pktnr += 1
|
||||
else:
|
||||
self._pktnr = next_id
|
||||
self._pktnr %= 256
|
||||
|
||||
async def _write_pkt(
|
||||
self,
|
||||
writer: StreamWriter,
|
||||
address: str,
|
||||
pkt: bytes,
|
||||
) -> None:
|
||||
"""Write packet to the comm channel."""
|
||||
try:
|
||||
writer.write(pkt)
|
||||
await writer.drain()
|
||||
except IOError as err:
|
||||
raise OperationalError(
|
||||
errno=2055, values=(address, _strioerror(err))
|
||||
) from err
|
||||
except AttributeError as err:
|
||||
raise OperationalError(errno=2006) from err
|
||||
|
||||
async def _read_chunk(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
size: int = 0,
|
||||
read_timeout: Optional[int] = None,
|
||||
) -> bytearray:
|
||||
"""Read `size` bytes from the comm channel."""
|
||||
try:
|
||||
pkt = bytearray(b"")
|
||||
while len(pkt) < size:
|
||||
chunk = await asyncio.wait_for(
|
||||
reader.read(size - len(pkt)), read_timeout
|
||||
)
|
||||
if not chunk:
|
||||
raise InterfaceError(errno=2013)
|
||||
pkt += chunk
|
||||
return pkt
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError) as err:
|
||||
raise ReadTimeoutError(errno=3024) from err
|
||||
|
||||
async def write(
|
||||
self,
|
||||
writer: StreamWriter,
|
||||
address: str,
|
||||
payload: bytes,
|
||||
packet_number: Optional[int] = None,
|
||||
compressed_packet_number: Optional[int] = None,
|
||||
write_timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Send payload to the MySQL server.
|
||||
|
||||
If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
|
||||
broken down into packets.
|
||||
"""
|
||||
self._set_next_pktnr(packet_number)
|
||||
# If the payload is larger than or equal to MAX_PAYLOAD_LENGTH the length is
|
||||
# set to 2^24 - 1 (ff ff ff) and additional packets are sent with the rest of
|
||||
# the payload until the payload of a packet is less than MAX_PAYLOAD_LENGTH.
|
||||
offset = 0
|
||||
try:
|
||||
for _ in range(len(payload) // MAX_PAYLOAD_LENGTH):
|
||||
# payload_len, sequence_id, payload
|
||||
await asyncio.wait_for(
|
||||
self._write_pkt(
|
||||
writer,
|
||||
address,
|
||||
b"\xff" * 3
|
||||
+ struct.pack("<B", self._pktnr)
|
||||
+ payload[offset : offset + MAX_PAYLOAD_LENGTH],
|
||||
),
|
||||
write_timeout,
|
||||
)
|
||||
self._set_next_pktnr()
|
||||
offset += MAX_PAYLOAD_LENGTH
|
||||
await asyncio.wait_for(
|
||||
self._write_pkt(
|
||||
writer,
|
||||
address,
|
||||
struct.pack("<I", len(payload) - offset)[0:3]
|
||||
+ struct.pack("<B", self._pktnr)
|
||||
+ payload[offset:],
|
||||
),
|
||||
write_timeout,
|
||||
)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError) as err:
|
||||
raise WriteTimeoutError(errno=3024) from err
|
||||
|
||||
async def read(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
address: str,
|
||||
read_timeout: Optional[int] = None,
|
||||
) -> bytearray:
|
||||
"""Receive `one` packet from the MySQL server."""
|
||||
try:
|
||||
# Read the header of the MySQL packet.
|
||||
header = await self._read_chunk(reader, PACKET_HEADER_LENGTH, read_timeout)
|
||||
|
||||
# Pull the payload length and sequence id.
|
||||
payload_len, self._pktnr = self.get_header(header)
|
||||
|
||||
# Read the payload, and return packet.
|
||||
return header + await self._read_chunk(reader, payload_len, read_timeout)
|
||||
except IOError as err:
|
||||
raise OperationalError(
|
||||
errno=2055, values=(address, _strioerror(err))
|
||||
) from err
|
||||
|
||||
|
||||
class NetworkBrokerCompressed(NetworkBrokerPlain):
|
||||
"""Broker class for MySQL socket communication."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._compressed_pktnr = -1
|
||||
self._queue_read: Deque[bytearray] = deque()
|
||||
|
||||
@staticmethod
|
||||
def _prepare_packets(payload: bytes, pktnr: int) -> List[bytes]:
|
||||
"""Prepare a payload for sending to the MySQL server."""
|
||||
offset = 0
|
||||
pkts = []
|
||||
|
||||
# If the payload is larger than or equal to MAX_PAYLOAD_LENGTH the length is
|
||||
# set to 2^24 - 1 (ff ff ff) and additional packets are sent with the rest of
|
||||
# the payload until the payload of a packet is less than MAX_PAYLOAD_LENGTH.
|
||||
for _ in range(len(payload) // MAX_PAYLOAD_LENGTH):
|
||||
# payload length + sequence id + payload
|
||||
pkts.append(
|
||||
b"\xff" * 3
|
||||
+ struct.pack("<B", pktnr)
|
||||
+ payload[offset : offset + MAX_PAYLOAD_LENGTH]
|
||||
)
|
||||
pktnr = (pktnr + 1) % 256
|
||||
offset += MAX_PAYLOAD_LENGTH
|
||||
pkts.append(
|
||||
struct.pack("<I", len(payload) - offset)[0:3]
|
||||
+ struct.pack("<B", pktnr)
|
||||
+ payload[offset:]
|
||||
)
|
||||
return pkts
|
||||
|
||||
@staticmethod
|
||||
def get_header(pkt: bytes) -> Tuple[int, int, int]: # type: ignore[override]
|
||||
"""Recover the header information from a packet."""
|
||||
if len(pkt) < COMPRESSED_PACKET_HEADER_LENGTH:
|
||||
raise ValueError("Can't recover header info from an incomplete packet")
|
||||
|
||||
compressed_pll, seqid, uncompressed_pll = (
|
||||
struct.unpack("<I", pkt[0:3] + b"\x00")[0],
|
||||
pkt[3],
|
||||
struct.unpack("<I", pkt[4:7] + b"\x00")[0],
|
||||
)
|
||||
# compressed payload length, sequence id, uncompressed payload length
|
||||
return compressed_pll, seqid, uncompressed_pll
|
||||
|
||||
def _set_next_compressed_pktnr(self, next_id: Optional[int] = None) -> None:
|
||||
"""Set the given packet id, if any, else increment packet id."""
|
||||
if next_id is None:
|
||||
self._compressed_pktnr += 1
|
||||
else:
|
||||
self._compressed_pktnr = next_id
|
||||
self._compressed_pktnr %= 256
|
||||
|
||||
async def _write_pkt(
|
||||
self,
|
||||
writer: StreamWriter,
|
||||
address: str,
|
||||
pkt: bytes,
|
||||
) -> None:
|
||||
"""Compress packet and write it to the comm channel."""
|
||||
compressed_pkt = zlib.compress(pkt)
|
||||
pkt = (
|
||||
struct.pack("<I", len(compressed_pkt))[0:3]
|
||||
+ struct.pack("<B", self._compressed_pktnr)
|
||||
+ struct.pack("<I", len(pkt))[0:3]
|
||||
+ compressed_pkt
|
||||
)
|
||||
return await super()._write_pkt(writer, address, pkt)
|
||||
|
||||
async def write(
|
||||
self,
|
||||
writer: StreamWriter,
|
||||
address: str,
|
||||
payload: bytes,
|
||||
packet_number: Optional[int] = None,
|
||||
compressed_packet_number: Optional[int] = None,
|
||||
write_timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Send `payload` as compressed packets to the MySQL server.
|
||||
|
||||
If provided a payload whose length is greater than `MAX_PAYLOAD_LENGTH`, it is
|
||||
broken down into packets.
|
||||
"""
|
||||
# Get next packet numbers.
|
||||
self._set_next_pktnr(packet_number)
|
||||
self._set_next_compressed_pktnr(compressed_packet_number)
|
||||
try:
|
||||
payload_prep = bytearray(b"").join(
|
||||
self._prepare_packets(payload, self._pktnr)
|
||||
)
|
||||
if len(payload) >= MAX_PAYLOAD_LENGTH - PACKET_HEADER_LENGTH:
|
||||
# Sending a MySQL payload of the size greater or equal to 2^24 - 5 via
|
||||
# compression leads to at least one extra compressed packet WHY? let's say
|
||||
# len(payload) is MAX_PAYLOAD_LENGTH - 3; when preparing the payload, a
|
||||
# header of size PACKET_HEADER_LENGTH is pre-appended to the payload.
|
||||
# This means that len(payload_prep) is
|
||||
# MAX_PAYLOAD_LENGTH - 3 + PACKET_HEADER_LENGTH = MAX_PAYLOAD_LENGTH + 1
|
||||
# surpassing the maximum allowed payload size per packet.
|
||||
offset = 0
|
||||
|
||||
# Send several MySQL packets.
|
||||
for _ in range(len(payload_prep) // MAX_PAYLOAD_LENGTH):
|
||||
await asyncio.wait_for(
|
||||
self._write_pkt(
|
||||
writer,
|
||||
address,
|
||||
payload_prep[offset : offset + MAX_PAYLOAD_LENGTH],
|
||||
),
|
||||
write_timeout,
|
||||
)
|
||||
self._set_next_compressed_pktnr()
|
||||
offset += MAX_PAYLOAD_LENGTH
|
||||
await asyncio.wait_for(
|
||||
self._write_pkt(writer, address, payload_prep[offset:]),
|
||||
write_timeout,
|
||||
)
|
||||
else:
|
||||
# Send one MySQL packet.
|
||||
# For small packets it may be too costly to compress the packet.
|
||||
# Usually payloads less than 50 bytes (MIN_COMPRESS_LENGTH) aren't
|
||||
# compressed (see MySQL source code Documentation).
|
||||
if len(payload) > MIN_COMPRESS_LENGTH:
|
||||
# Perform compression.
|
||||
await asyncio.wait_for(
|
||||
self._write_pkt(writer, address, payload_prep), write_timeout
|
||||
)
|
||||
else:
|
||||
# Skip compression.
|
||||
await asyncio.wait_for(
|
||||
super()._write_pkt(
|
||||
writer,
|
||||
address,
|
||||
struct.pack("<I", len(payload_prep))[0:3]
|
||||
+ struct.pack("<B", self._compressed_pktnr)
|
||||
+ struct.pack("<I", 0)[0:3]
|
||||
+ payload_prep,
|
||||
),
|
||||
write_timeout,
|
||||
)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError) as err:
|
||||
raise WriteTimeoutError(errno=3024) from err
|
||||
|
||||
async def _read_compressed_pkt(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
compressed_pll: int,
|
||||
read_timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Handle reading of a compressed packet."""
|
||||
# compressed_pll stands for compressed payload length.
|
||||
pkt = bytearray(
|
||||
zlib.decompress(
|
||||
await super()._read_chunk(reader, compressed_pll, read_timeout)
|
||||
)
|
||||
)
|
||||
offset = 0
|
||||
while offset < len(pkt):
|
||||
# pll stands for payload length
|
||||
pll = struct.unpack(
|
||||
"<I", pkt[offset : offset + PACKET_HEADER_LENGTH - 1] + b"\x00"
|
||||
)[0]
|
||||
if PACKET_HEADER_LENGTH + pll > len(pkt) - offset:
|
||||
# More bytes need to be consumed.
|
||||
# Read the header of the next MySQL packet.
|
||||
header = await super()._read_chunk(
|
||||
reader, COMPRESSED_PACKET_HEADER_LENGTH, read_timeout
|
||||
)
|
||||
|
||||
# compressed payload length, sequence id, uncompressed payload length.
|
||||
(
|
||||
compressed_pll,
|
||||
self._compressed_pktnr,
|
||||
uncompressed_pll,
|
||||
) = self.get_header(header)
|
||||
compressed_pkt = await super()._read_chunk(
|
||||
reader, compressed_pll, read_timeout
|
||||
)
|
||||
|
||||
# Recalling that if uncompressed payload length == 0, the packet comes
|
||||
# in uncompressed, so no decompression is needed.
|
||||
pkt += (
|
||||
compressed_pkt
|
||||
if uncompressed_pll == 0
|
||||
else zlib.decompress(compressed_pkt)
|
||||
)
|
||||
|
||||
self._queue_read.append(pkt[offset : offset + PACKET_HEADER_LENGTH + pll])
|
||||
offset += PACKET_HEADER_LENGTH + pll
|
||||
|
||||
async def read(
|
||||
self,
|
||||
reader: asyncio.StreamReader,
|
||||
address: str,
|
||||
read_timeout: Optional[int] = None,
|
||||
) -> bytearray:
|
||||
"""Receive `one` or `several` packets from the MySQL server, enqueue them, and
|
||||
return the packet at the head.
|
||||
"""
|
||||
|
||||
if not self._queue_read:
|
||||
try:
|
||||
# Read the header of the next MySQL packet.
|
||||
header = await super()._read_chunk(
|
||||
reader, COMPRESSED_PACKET_HEADER_LENGTH, read_timeout
|
||||
)
|
||||
|
||||
# compressed payload length, sequence id, uncompressed payload length
|
||||
(
|
||||
compressed_pll,
|
||||
self._compressed_pktnr,
|
||||
uncompressed_pll,
|
||||
) = self.get_header(header)
|
||||
|
||||
if uncompressed_pll == 0:
|
||||
# Packet is not compressed, so just store it.
|
||||
self._queue_read.append(
|
||||
await super()._read_chunk(reader, compressed_pll, read_timeout)
|
||||
)
|
||||
else:
|
||||
# Packet comes in compressed, further action is needed.
|
||||
await self._read_compressed_pkt(
|
||||
reader, compressed_pll, read_timeout
|
||||
)
|
||||
except IOError as err:
|
||||
raise OperationalError(
|
||||
errno=2055, values=(address, _strioerror(err))
|
||||
) from err
|
||||
|
||||
if not self._queue_read:
|
||||
return None
|
||||
|
||||
pkt = self._queue_read.popleft()
|
||||
self._pktnr = pkt[3]
|
||||
|
||||
return pkt
|
||||
|
||||
|
||||
class MySQLSocket(ABC):
|
||||
"""MySQL socket communication interface.
|
||||
|
||||
Examples:
|
||||
Subclasses: network.MySQLTCPSocket and network.MySQLUnixSocket.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Network layer where transactions are made with plain (uncompressed) packets
|
||||
is enabled by default.
|
||||
"""
|
||||
self._reader: Optional[asyncio.StreamReader] = None
|
||||
self._writer: Optional[StreamWriter] = None
|
||||
self._connection_timeout: Optional[int] = None
|
||||
self._address: Optional[str] = None
|
||||
self._netbroker: NetworkBroker = NetworkBrokerPlain()
|
||||
self._is_connected: bool = False
|
||||
|
||||
@property
|
||||
def address(self) -> str:
|
||||
"""Socket location."""
|
||||
return self._address
|
||||
|
||||
@abstractmethod
|
||||
async def open_connection(self, **kwargs: Any) -> None:
|
||||
"""Open the socket."""
|
||||
|
||||
async def close_connection(self) -> None:
|
||||
"""Close the connection."""
|
||||
if self._writer:
|
||||
try:
|
||||
self._writer.close()
|
||||
# Without transport.abort(), an error is raised when using SSL
|
||||
if self._writer.transport is not None:
|
||||
self._writer.transport.abort()
|
||||
await self._writer.wait_closed()
|
||||
except Exception as _: # pylint: disable=broad-exception-caught)
|
||||
# we can ignore issues like ConnectionRefused or ConnectionAborted
|
||||
# as these instances might popup if the connection was closed due to timeout issues
|
||||
pass
|
||||
self._is_connected = False
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the socket is connected.
|
||||
|
||||
Return:
|
||||
bool: Returns `True` if the socket is connected to MySQL server.
|
||||
"""
|
||||
return self._is_connected
|
||||
|
||||
def set_connection_timeout(self, timeout: int) -> None:
|
||||
"""Set the connection timeout."""
|
||||
self._connection_timeout = timeout
|
||||
|
||||
def switch_to_compressed_mode(self) -> None:
|
||||
"""Enable network layer where transactions are made with compressed packets."""
|
||||
self._netbroker = NetworkBrokerCompressed()
|
||||
|
||||
async def switch_to_ssl(self, ssl_context: ssl.SSLContext) -> None:
|
||||
"""Upgrade an existing stream-based connection to TLS.
|
||||
|
||||
The `start_tls()` method from `asyncio.streams.StreamWriter` is only available
|
||||
in Python 3.11. This method is used as a workaround.
|
||||
|
||||
The MySQL TLS negotiation happens in the middle of the TCP connection.
|
||||
Therefore, passing a socket to open connection will cause it to negotiate
|
||||
TLS on an existing connection.
|
||||
|
||||
Args:
|
||||
ssl_context: The SSL Context to be used.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the transport does not expose the socket instance.
|
||||
"""
|
||||
# Ensure that self._writer is already created
|
||||
assert self._writer is not None
|
||||
|
||||
socket = self._writer.transport.get_extra_info("socket")
|
||||
if socket.family == 1: # socket.AF_UNIX
|
||||
raise ProgrammingError("SSL is not supported when using Unix sockets")
|
||||
|
||||
await self._writer.start_tls(ssl_context)
|
||||
|
||||
async def write(
|
||||
self,
|
||||
payload: bytes,
|
||||
packet_number: Optional[int] = None,
|
||||
compressed_packet_number: Optional[int] = None,
|
||||
write_timeout: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Send packets to the MySQL server."""
|
||||
await self._netbroker.write(
|
||||
self._writer,
|
||||
self.address,
|
||||
payload,
|
||||
packet_number=packet_number,
|
||||
compressed_packet_number=compressed_packet_number,
|
||||
write_timeout=write_timeout,
|
||||
)
|
||||
|
||||
async def read(self, read_timeout: Optional[int] = None) -> bytearray:
|
||||
"""Read packets from the MySQL server."""
|
||||
return await self._netbroker.read(self._reader, self.address, read_timeout)
|
||||
|
||||
def build_ssl_context(
|
||||
self,
|
||||
ssl_ca: Optional[str] = None,
|
||||
ssl_cert: Optional[str] = None,
|
||||
ssl_key: Optional[str] = None,
|
||||
ssl_verify_cert: Optional[bool] = False,
|
||||
ssl_verify_identity: Optional[bool] = False,
|
||||
tls_versions: Optional[List[str]] = [],
|
||||
tls_cipher_suites: Optional[List[str]] = [],
|
||||
) -> ssl.SSLContext:
|
||||
"""Build a SSLContext."""
|
||||
tls_version: Optional[str] = None
|
||||
|
||||
if not self._reader:
|
||||
raise InterfaceError(errno=2048)
|
||||
|
||||
if ssl is None:
|
||||
raise RuntimeError("Python installation has no SSL support")
|
||||
|
||||
try:
|
||||
if tls_versions:
|
||||
tls_versions.sort(reverse=True)
|
||||
tls_version = tls_versions[0]
|
||||
ssl_protocol = TLS_VERSIONS[tls_version]
|
||||
context = ssl.SSLContext(ssl_protocol)
|
||||
|
||||
if tls_version == "TLSv1.3":
|
||||
if "TLSv1.2" not in tls_versions:
|
||||
context.options |= ssl.OP_NO_TLSv1_2
|
||||
if "TLSv1.1" not in tls_versions:
|
||||
context.options |= ssl.OP_NO_TLSv1_1
|
||||
if "TLSv1" not in tls_versions:
|
||||
context.options |= ssl.OP_NO_TLSv1
|
||||
else:
|
||||
context = ssl.create_default_context()
|
||||
|
||||
context.check_hostname = ssl_verify_identity
|
||||
|
||||
if ssl_verify_cert:
|
||||
context.verify_mode = ssl.CERT_REQUIRED
|
||||
elif ssl_verify_identity:
|
||||
context.verify_mode = ssl.CERT_OPTIONAL
|
||||
else:
|
||||
context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
context.load_default_certs()
|
||||
|
||||
if ssl_ca:
|
||||
try:
|
||||
context.load_verify_locations(ssl_ca)
|
||||
except (IOError, ssl.SSLError) as err:
|
||||
raise InterfaceError(f"Invalid CA Certificate: {err}") from err
|
||||
if ssl_cert:
|
||||
try:
|
||||
context.load_cert_chain(ssl_cert, ssl_key)
|
||||
except (IOError, ssl.SSLError) as err:
|
||||
raise InterfaceError(f"Invalid Certificate/Key: {err}") from err
|
||||
|
||||
# TLSv1.3 ciphers cannot be disabled with `SSLContext.set_ciphers(...)`,
|
||||
# see https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers.
|
||||
if tls_cipher_suites and tls_version == "TLSv1.2":
|
||||
context.set_ciphers(":".join(tls_cipher_suites))
|
||||
|
||||
return context
|
||||
except NameError as err:
|
||||
raise NotSupportedError("Python installation has no SSL support") from err
|
||||
except (
|
||||
IOError,
|
||||
NotImplementedError,
|
||||
ssl.CertificateError,
|
||||
ssl.SSLError,
|
||||
) as err:
|
||||
raise InterfaceError(str(err)) from err
|
||||
|
||||
|
||||
class MySQLTcpSocket(MySQLSocket):
|
||||
"""MySQL socket class using TCP/IP.
|
||||
|
||||
Args:
|
||||
host: MySQL host name.
|
||||
port: MySQL port.
|
||||
force_ipv6: Force IPv6 usage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, host: str = "127.0.0.1", port: int = 3306, force_ipv6: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self._host: str = host
|
||||
self._port: int = port
|
||||
self._force_ipv6: bool = force_ipv6
|
||||
self._address: str = f"{host}:{port}"
|
||||
|
||||
async def open_connection(self, **kwargs: Any) -> None:
|
||||
"""Open TCP/IP connection."""
|
||||
self._reader, self._writer = await open_connection(
|
||||
host=self._host, port=self._port, **kwargs
|
||||
)
|
||||
self._is_connected = True
|
||||
|
||||
|
||||
class MySQLUnixSocket(MySQLSocket):
|
||||
"""MySQL socket class using UNIX sockets.
|
||||
|
||||
Args:
|
||||
unix_socket: UNIX socket file path.
|
||||
"""
|
||||
|
||||
def __init__(self, unix_socket: str = "/tmp/mysql.sock"):
|
||||
super().__init__()
|
||||
self._address: str = unix_socket
|
||||
|
||||
async def open_connection(self, **kwargs: Any) -> None:
|
||||
"""Open UNIX socket connection."""
|
||||
(
|
||||
self._reader,
|
||||
self._writer,
|
||||
) = await asyncio.open_unix_connection( # type: ignore[assignment]
|
||||
path=self._address, **kwargs
|
||||
)
|
||||
self._is_connected = True
|
||||
@@ -0,0 +1,162 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Base Authentication Plugin class."""
|
||||
|
||||
__all__ = ["MySQLAuthPlugin", "get_auth_plugin"]
|
||||
|
||||
import importlib
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type
|
||||
|
||||
from mysql.connector.errors import NotSupportedError, ProgrammingError
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
DEFAULT_PLUGINS_PKG = "mysql.connector.aio.plugins"
|
||||
|
||||
|
||||
class MySQLAuthPlugin(ABC):
|
||||
"""Authorization plugin interface."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
ssl_enabled: bool = False,
|
||||
) -> None:
|
||||
"""Constructor."""
|
||||
self._username: str = "" if username is None else username
|
||||
self._password: str = "" if password is None else password
|
||||
self._ssl_enabled: bool = ssl_enabled
|
||||
|
||||
@property
|
||||
def ssl_enabled(self) -> bool:
|
||||
"""Signals whether or not SSL is enabled."""
|
||||
return self._ssl_enabled
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
|
||||
@abstractmethod
|
||||
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
|
||||
"""Make the client's authorization response.
|
||||
|
||||
Args:
|
||||
auth_data: Authorization data.
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Client's authorization response.
|
||||
"""
|
||||
|
||||
async def auth_more_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth more data` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Authentication method data (from a packet representing
|
||||
an `auth more data` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth communication.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth communication.
|
||||
"""
|
||||
|
||||
|
||||
@lru_cache(maxsize=10, typed=False)
|
||||
def get_auth_plugin(
|
||||
plugin_name: str,
|
||||
auth_plugin_class: Optional[str] = None,
|
||||
) -> Type[MySQLAuthPlugin]:
|
||||
"""Return authentication class based on plugin name
|
||||
|
||||
This function returns the class for the authentication plugin plugin_name.
|
||||
The returned class is a subclass of BaseAuthPlugin.
|
||||
|
||||
Args:
|
||||
plugin_name (str): Authentication plugin name.
|
||||
auth_plugin_class (str): Authentication plugin class name.
|
||||
|
||||
Raises:
|
||||
NotSupportedError: When plugin_name is not supported.
|
||||
|
||||
Returns:
|
||||
Subclass of `MySQLAuthPlugin`.
|
||||
"""
|
||||
package = DEFAULT_PLUGINS_PKG
|
||||
if plugin_name:
|
||||
try:
|
||||
logger.info("package: %s", package)
|
||||
logger.info("plugin_name: %s", plugin_name)
|
||||
plugin_module = importlib.import_module(f".{plugin_name}", package)
|
||||
if not auth_plugin_class or not hasattr(plugin_module, auth_plugin_class):
|
||||
auth_plugin_class = plugin_module.AUTHENTICATION_PLUGIN_CLASS
|
||||
logger.info("AUTHENTICATION_PLUGIN_CLASS: %s", auth_plugin_class)
|
||||
return getattr(plugin_module, auth_plugin_class)
|
||||
except ModuleNotFoundError as err:
|
||||
logger.warning("Requested Module was not found: %s", err)
|
||||
except ValueError as err:
|
||||
raise ProgrammingError(f"Invalid module name: {err}") from err
|
||||
raise NotSupportedError(f"Authentication plugin '{plugin_name}' is not supported")
|
||||
@@ -0,0 +1,577 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
# mypy: disable-error-code="str-bytes-safe,misc"
|
||||
|
||||
"""Kerberos Authentication Plugin."""
|
||||
|
||||
import getpass
|
||||
import os
|
||||
import struct
|
||||
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||
|
||||
from mysql.connector.errors import InterfaceError, ProgrammingError
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
from ..authentication import ERR_STATUS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
try:
|
||||
import gssapi
|
||||
except ImportError:
|
||||
gssapi = None
|
||||
if os.name != "nt":
|
||||
raise ProgrammingError(
|
||||
"Module gssapi is required for GSSAPI authentication "
|
||||
"mechanism but was not found. Unable to authenticate "
|
||||
"with the server"
|
||||
) from None
|
||||
|
||||
try:
|
||||
import sspi
|
||||
import sspicon
|
||||
except ImportError:
|
||||
sspi = None
|
||||
sspicon = None
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = (
|
||||
"MySQLSSPIKerberosAuthPlugin" if os.name == "nt" else "MySQLKerberosAuthPlugin"
|
||||
)
|
||||
|
||||
|
||||
class MySQLBaseKerberosAuthPlugin(MySQLAuthPlugin):
|
||||
"""Base class for the MySQL Kerberos authentication plugin."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "authentication_kerberos_client"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def auth_continue(
|
||||
self, tgt_auth_challenge: Optional[bytes]
|
||||
) -> Tuple[Optional[bytes], bool]:
|
||||
"""Continue with the Kerberos TGT service request.
|
||||
|
||||
With the TGT authentication service given response generate a TGT
|
||||
service request. This method must be invoked sequentially (in a loop)
|
||||
until the security context is completed and an empty response needs to
|
||||
be send to acknowledge the server.
|
||||
|
||||
Args:
|
||||
tgt_auth_challenge: the challenge for the negotiation.
|
||||
|
||||
Returns:
|
||||
tuple (bytearray TGS service request,
|
||||
bool True if context is completed otherwise False).
|
||||
"""
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
logger.debug("# auth_data: %s", auth_data)
|
||||
response = self.auth_response(auth_data, ignore_auth_data=False, **kwargs)
|
||||
if response is None:
|
||||
raise InterfaceError("Got a NULL auth response")
|
||||
|
||||
logger.debug("# request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
packet = await sock.read()
|
||||
logger.debug("# server response packet: %s", packet)
|
||||
|
||||
if packet != ERR_STATUS:
|
||||
rcode_size = 5 # Reader size for the response status code
|
||||
logger.debug("# Continue with GSSAPI authentication")
|
||||
logger.debug("# Response header: %s", packet[: rcode_size + 1])
|
||||
logger.debug("# Response size: %s", len(packet))
|
||||
logger.debug("# Negotiate a service request")
|
||||
complete = False
|
||||
tries = 0
|
||||
|
||||
while not complete and tries < 5:
|
||||
logger.debug("%s Attempt %s %s", "-" * 20, tries + 1, "-" * 20)
|
||||
logger.debug("<< Server response: %s", packet)
|
||||
logger.debug("# Response code: %s", packet[: rcode_size + 1])
|
||||
token, complete = self.auth_continue(packet[rcode_size:])
|
||||
if token:
|
||||
await sock.write(token)
|
||||
if complete:
|
||||
break
|
||||
packet = await sock.read()
|
||||
|
||||
logger.debug(">> Response to server: %s", token)
|
||||
tries += 1
|
||||
|
||||
if not complete:
|
||||
raise InterfaceError(
|
||||
f"Unable to fulfill server request after {tries} "
|
||||
f"attempts. Last server response: {packet}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Last response from server: %s length: %d",
|
||||
packet,
|
||||
len(packet),
|
||||
)
|
||||
|
||||
# Receive OK packet from server.
|
||||
packet = await sock.read()
|
||||
logger.debug("<< Ok packet from server: %s", packet)
|
||||
|
||||
return bytes(packet)
|
||||
|
||||
|
||||
# pylint: disable=c-extension-no-member,no-member
|
||||
class MySQLKerberosAuthPlugin(MySQLBaseKerberosAuthPlugin):
|
||||
"""Implement the MySQL Kerberos authentication plugin."""
|
||||
|
||||
context: Optional[gssapi.SecurityContext] = None
|
||||
|
||||
@staticmethod
|
||||
def get_user_from_credentials() -> str:
|
||||
"""Get user from credentials without realm."""
|
||||
try:
|
||||
creds = gssapi.Credentials(usage="initiate")
|
||||
user = str(creds.name)
|
||||
if user.find("@") != -1:
|
||||
user, _ = user.split("@", 1)
|
||||
return user
|
||||
except gssapi.raw.misc.GSSError:
|
||||
return getpass.getuser()
|
||||
|
||||
@staticmethod
|
||||
def get_store() -> dict:
|
||||
"""Get a credentials store dictionary.
|
||||
|
||||
Returns:
|
||||
dict: Credentials store dictionary with the krb5 ccache name.
|
||||
|
||||
Raises:
|
||||
InterfaceError: If 'KRB5CCNAME' environment variable is empty.
|
||||
"""
|
||||
krb5ccname = os.environ.get(
|
||||
"KRB5CCNAME",
|
||||
(
|
||||
f"/tmp/krb5cc_{os.getuid()}"
|
||||
if os.name == "posix"
|
||||
else Path("%TEMP%").joinpath("krb5cc")
|
||||
),
|
||||
)
|
||||
if not krb5ccname:
|
||||
raise InterfaceError(
|
||||
"The 'KRB5CCNAME' environment variable is set to empty"
|
||||
)
|
||||
logger.debug("Using krb5 ccache name: FILE:%s", krb5ccname)
|
||||
store = {b"ccache": f"FILE:{krb5ccname}".encode("utf-8")}
|
||||
return store
|
||||
|
||||
def _acquire_cred_with_password(self, upn: str) -> gssapi.raw.creds.Creds:
|
||||
"""Acquire and store credentials through provided password.
|
||||
|
||||
Args:
|
||||
upn (str): User Principal Name.
|
||||
|
||||
Returns:
|
||||
gssapi.raw.creds.Creds: GSSAPI credentials.
|
||||
"""
|
||||
logger.debug("Attempt to acquire credentials through provided password")
|
||||
user = gssapi.Name(upn, gssapi.NameType.user)
|
||||
password = self._password.encode("utf-8")
|
||||
|
||||
try:
|
||||
acquire_cred_result = gssapi.raw.acquire_cred_with_password(
|
||||
user, password, usage="initiate"
|
||||
)
|
||||
creds = acquire_cred_result.creds
|
||||
gssapi.raw.store_cred_into(
|
||||
self.get_store(),
|
||||
creds=creds,
|
||||
mech=gssapi.MechType.kerberos,
|
||||
overwrite=True,
|
||||
set_default=True,
|
||||
)
|
||||
except gssapi.raw.misc.GSSError as err:
|
||||
raise ProgrammingError(
|
||||
f"Unable to acquire credentials with the given password: {err}"
|
||||
) from err
|
||||
return creds
|
||||
|
||||
@staticmethod
|
||||
def _parse_auth_data(packet: bytes) -> Tuple[str, str]:
|
||||
"""Parse authentication data.
|
||||
|
||||
Get the SPN and REALM from the authentication data packet.
|
||||
|
||||
Format:
|
||||
SPN string length two bytes <B1> <B2> +
|
||||
SPN string +
|
||||
UPN realm string length two bytes <B1> <B2> +
|
||||
UPN realm string
|
||||
|
||||
Returns:
|
||||
tuple: With 'spn' and 'realm'.
|
||||
"""
|
||||
spn_len = struct.unpack("<H", packet[:2])[0]
|
||||
packet = packet[2:]
|
||||
|
||||
spn = struct.unpack(f"<{spn_len}s", packet[:spn_len])[0]
|
||||
packet = packet[spn_len:]
|
||||
|
||||
realm_len = struct.unpack("<H", packet[:2])[0]
|
||||
realm = struct.unpack(f"<{realm_len}s", packet[2:])[0]
|
||||
|
||||
return spn.decode(), realm.decode()
|
||||
|
||||
def auth_response(
|
||||
self, auth_data: Optional[bytes] = None, **kwargs: Any
|
||||
) -> Optional[bytes]:
|
||||
"""Prepare the first message to the server."""
|
||||
spn = None
|
||||
realm = None
|
||||
|
||||
if auth_data and not kwargs.get("ignore_auth_data", True):
|
||||
try:
|
||||
spn, realm = self._parse_auth_data(auth_data)
|
||||
except struct.error as err:
|
||||
raise InterruptedError(f"Invalid authentication data: {err}") from err
|
||||
|
||||
if spn is None:
|
||||
return self._password.encode() + b"\x00"
|
||||
|
||||
upn = f"{self._username}@{realm}" if self._username else None
|
||||
|
||||
logger.debug("Service Principal: %s", spn)
|
||||
logger.debug("Realm: %s", realm)
|
||||
|
||||
try:
|
||||
# Attempt to retrieve credentials from cache file
|
||||
creds: Any = gssapi.Credentials(usage="initiate")
|
||||
creds_upn = str(creds.name)
|
||||
|
||||
logger.debug("Cached credentials found")
|
||||
logger.debug("Cached credentials UPN: %s", creds_upn)
|
||||
|
||||
# Remove the realm from user
|
||||
if creds_upn.find("@") != -1:
|
||||
creds_user, creds_realm = creds_upn.split("@", 1)
|
||||
else:
|
||||
creds_user = creds_upn
|
||||
creds_realm = None
|
||||
|
||||
upn = f"{self._username}@{realm}" if self._username else creds_upn
|
||||
|
||||
# The user from cached credentials matches with the given user?
|
||||
if self._username and self._username != creds_user:
|
||||
logger.debug(
|
||||
"The user from cached credentials doesn't match with the "
|
||||
"given user"
|
||||
)
|
||||
if self._password is not None:
|
||||
creds = self._acquire_cred_with_password(upn)
|
||||
if creds_realm and creds_realm != realm and self._password is not None:
|
||||
creds = self._acquire_cred_with_password(upn)
|
||||
except gssapi.raw.exceptions.ExpiredCredentialsError as err:
|
||||
if upn and self._password is not None:
|
||||
creds = self._acquire_cred_with_password(upn)
|
||||
else:
|
||||
raise InterfaceError(f"Credentials has expired: {err}") from err
|
||||
except gssapi.raw.misc.GSSError as err:
|
||||
if upn and self._password is not None:
|
||||
creds = self._acquire_cred_with_password(upn)
|
||||
else:
|
||||
raise InterfaceError(
|
||||
f"Unable to retrieve cached credentials error: {err}"
|
||||
) from err
|
||||
|
||||
flags = (
|
||||
gssapi.RequirementFlag.mutual_authentication,
|
||||
gssapi.RequirementFlag.extended_error,
|
||||
gssapi.RequirementFlag.delegate_to_peer,
|
||||
)
|
||||
name = gssapi.Name(spn, name_type=gssapi.NameType.kerberos_principal)
|
||||
cname = name.canonicalize(gssapi.MechType.kerberos)
|
||||
self.context = gssapi.SecurityContext(
|
||||
name=cname, creds=creds, flags=sum(flags), usage="initiate"
|
||||
)
|
||||
|
||||
try:
|
||||
initial_client_token: Optional[bytes] = self.context.step()
|
||||
except gssapi.raw.misc.GSSError as err:
|
||||
raise InterfaceError(f"Unable to initiate security context: {err}") from err
|
||||
|
||||
logger.debug("Initial client token: %s", initial_client_token)
|
||||
return initial_client_token
|
||||
|
||||
def auth_continue(
|
||||
self, tgt_auth_challenge: Optional[bytes]
|
||||
) -> Tuple[Optional[bytes], bool]:
|
||||
"""Continue with the Kerberos TGT service request.
|
||||
|
||||
With the TGT authentication service given response generate a TGT
|
||||
service request. This method must be invoked sequentially (in a loop)
|
||||
until the security context is completed and an empty response needs to
|
||||
be send to acknowledge the server.
|
||||
|
||||
Args:
|
||||
tgt_auth_challenge: the challenge for the negotiation.
|
||||
|
||||
Returns:
|
||||
tuple (bytearray TGS service request,
|
||||
bool True if context is completed otherwise False).
|
||||
"""
|
||||
logger.debug("tgt_auth challenge: %s", tgt_auth_challenge)
|
||||
|
||||
resp: Optional[bytes] = self.context.step(tgt_auth_challenge)
|
||||
|
||||
logger.debug("Context step response: %s", resp)
|
||||
logger.debug("Context completed?: %s", self.context.complete)
|
||||
|
||||
return resp, self.context.complete
|
||||
|
||||
def auth_accept_close_handshake(self, message: bytes) -> bytes:
|
||||
"""Accept handshake and generate closing handshake message for server.
|
||||
|
||||
This method verifies the server authenticity from the given message
|
||||
and included signature and generates the closing handshake for the
|
||||
server.
|
||||
|
||||
When this method is invoked the security context is already established
|
||||
and the client and server can send GSSAPI formated secure messages.
|
||||
|
||||
To finish the authentication handshake the server sends a message
|
||||
with the security layer availability and the maximum buffer size.
|
||||
|
||||
Since the connector only uses the GSSAPI authentication mechanism to
|
||||
authenticate the user with the server, the server will verify clients
|
||||
message signature and terminate the GSSAPI authentication and send two
|
||||
messages; an authentication acceptance b'\x01\x00\x00\x08\x01' and a
|
||||
OK packet (that must be received after sent the returned message from
|
||||
this method).
|
||||
|
||||
Args:
|
||||
message: a wrapped gssapi message from the server.
|
||||
|
||||
Returns:
|
||||
bytearray (closing handshake message to be send to the server).
|
||||
"""
|
||||
if not self.context.complete:
|
||||
raise ProgrammingError("Security context is not completed")
|
||||
logger.debug("Server message: %s", message)
|
||||
logger.debug("GSSAPI flags in use: %s", self.context.actual_flags)
|
||||
try:
|
||||
unwraped = self.context.unwrap(message)
|
||||
logger.debug("Unwraped: %s", unwraped)
|
||||
except gssapi.raw.exceptions.BadMICError as err:
|
||||
logger.debug("Unable to unwrap server message: %s", err)
|
||||
raise InterfaceError(f"Unable to unwrap server message: {err}") from err
|
||||
|
||||
logger.debug("Unwrapped server message: %s", unwraped)
|
||||
# The message contents for the clients closing message:
|
||||
# - security level 1 byte, must be always 1.
|
||||
# - conciliated buffer size 3 bytes, without importance as no
|
||||
# further GSSAPI messages will be sends.
|
||||
response = bytearray(b"\x01\x00\x00\00")
|
||||
# Closing handshake must not be encrypted.
|
||||
logger.debug("Message response: %s", response)
|
||||
wraped = self.context.wrap(response, encrypt=False)
|
||||
logger.debug(
|
||||
"Wrapped message response: %s, length: %d",
|
||||
wraped[0],
|
||||
len(wraped[0]),
|
||||
)
|
||||
|
||||
return wraped.message
|
||||
|
||||
|
||||
class MySQLSSPIKerberosAuthPlugin(MySQLBaseKerberosAuthPlugin):
|
||||
"""Implement the MySQL Kerberos authentication plugin with Windows SSPI"""
|
||||
|
||||
context: Any = None
|
||||
clientauth: Any = None
|
||||
|
||||
@staticmethod
|
||||
def _parse_auth_data(packet: bytes) -> Tuple[str, str]:
|
||||
"""Parse authentication data.
|
||||
|
||||
Get the SPN and REALM from the authentication data packet.
|
||||
|
||||
Format:
|
||||
SPN string length two bytes <B1> <B2> +
|
||||
SPN string +
|
||||
UPN realm string length two bytes <B1> <B2> +
|
||||
UPN realm string
|
||||
|
||||
Returns:
|
||||
tuple: With 'spn' and 'realm'.
|
||||
"""
|
||||
spn_len = struct.unpack("<H", packet[:2])[0]
|
||||
packet = packet[2:]
|
||||
|
||||
spn = struct.unpack(f"<{spn_len}s", packet[:spn_len])[0]
|
||||
packet = packet[spn_len:]
|
||||
|
||||
realm_len = struct.unpack("<H", packet[:2])[0]
|
||||
realm = struct.unpack(f"<{realm_len}s", packet[2:])[0]
|
||||
|
||||
return spn.decode(), realm.decode()
|
||||
|
||||
def auth_response(
|
||||
self, auth_data: Optional[bytes] = None, **kwargs: Any
|
||||
) -> Optional[bytes]:
|
||||
"""Prepare the first message to the server.
|
||||
|
||||
Args:
|
||||
kwargs:
|
||||
ignore_auth_data (bool): if True, the provided auth data is ignored.
|
||||
"""
|
||||
logger.debug("auth_response for sspi")
|
||||
spn = None
|
||||
realm = None
|
||||
|
||||
if auth_data and not kwargs.get("ignore_auth_data", True):
|
||||
try:
|
||||
spn, realm = self._parse_auth_data(auth_data)
|
||||
except struct.error as err:
|
||||
raise InterruptedError(f"Invalid authentication data: {err}") from err
|
||||
|
||||
logger.debug("Service Principal: %s", spn)
|
||||
logger.debug("Realm: %s", realm)
|
||||
|
||||
if sspicon is None or sspi is None:
|
||||
raise ProgrammingError(
|
||||
'Package "pywin32" (Python for Win32 (pywin32) extensions)'
|
||||
" is not installed."
|
||||
)
|
||||
|
||||
flags = (sspicon.ISC_REQ_MUTUAL_AUTH, sspicon.ISC_REQ_DELEGATE)
|
||||
|
||||
if self._username and self._password:
|
||||
_auth_info = (self._username, realm, self._password)
|
||||
else:
|
||||
_auth_info = None
|
||||
|
||||
targetspn = spn
|
||||
logger.debug("targetspn: %s", targetspn)
|
||||
logger.debug("_auth_info is None: %s", _auth_info is None)
|
||||
|
||||
# The Security Support Provider Interface (SSPI) is an interface
|
||||
# that allows us to choose from a set of SSPs available in the
|
||||
# system; the idea of SSPI is to keep interface consistent no
|
||||
# matter what back end (a.k.a., SSP) we choose.
|
||||
|
||||
# When using SSPI we should not use Kerberos directly as SSP,
|
||||
# as remarked in [2], but we can use it indirectly via another
|
||||
# SSP named Negotiate that acts as an application layer between
|
||||
# SSPI and the other SSPs [1].
|
||||
|
||||
# Negotiate can select between Kerberos and NTLM on the fly;
|
||||
# it chooses Kerberos unless it cannot be used by one of the
|
||||
# systems involved in the authentication or the calling
|
||||
# application did not provide sufficient information to use
|
||||
# Kerberos.
|
||||
|
||||
# prefix: https://docs.microsoft.com/en-us/windows/win32/secauthn
|
||||
# [1] prefix/microsoft-negotiate?source=recommendations
|
||||
# [2] prefix/microsoft-kerberos?source=recommendations
|
||||
self.clientauth = sspi.ClientAuth(
|
||||
"Negotiate",
|
||||
targetspn=targetspn,
|
||||
auth_info=_auth_info,
|
||||
scflags=sum(flags),
|
||||
datarep=sspicon.SECURITY_NETWORK_DREP,
|
||||
)
|
||||
|
||||
try:
|
||||
data = None
|
||||
err, out_buf = self.clientauth.authorize(data)
|
||||
logger.debug("Context step err: %s", err)
|
||||
logger.debug("Context step out_buf: %s", out_buf)
|
||||
logger.debug("Context completed?: %s", self.clientauth.authenticated)
|
||||
initial_client_token = out_buf[0].Buffer
|
||||
logger.debug("pkg_info: %s", self.clientauth.pkg_info)
|
||||
except Exception as err:
|
||||
raise InterfaceError(f"Unable to initiate security context: {err}") from err
|
||||
|
||||
logger.debug("Initial client token: %s", initial_client_token)
|
||||
return initial_client_token
|
||||
|
||||
def auth_continue(
|
||||
self, tgt_auth_challenge: Optional[bytes]
|
||||
) -> Tuple[Optional[bytes], bool]:
|
||||
"""Continue with the Kerberos TGT service request.
|
||||
|
||||
With the TGT authentication service given response generate a TGT
|
||||
service request. This method must be invoked sequentially (in a loop)
|
||||
until the security context is completed and an empty response needs to
|
||||
be send to acknowledge the server.
|
||||
|
||||
Args:
|
||||
tgt_auth_challenge: the challenge for the negotiation.
|
||||
|
||||
Returns:
|
||||
tuple (bytearray TGS service request,
|
||||
bool True if context is completed otherwise False).
|
||||
"""
|
||||
logger.debug("tgt_auth challenge: %s", tgt_auth_challenge)
|
||||
|
||||
err, out_buf = self.clientauth.authorize(tgt_auth_challenge)
|
||||
|
||||
logger.debug("Context step err: %s", err)
|
||||
logger.debug("Context step out_buf: %s", out_buf)
|
||||
resp = out_buf[0].Buffer
|
||||
logger.debug("Context step resp: %s", resp)
|
||||
logger.debug("Context completed?: %s", self.clientauth.authenticated)
|
||||
|
||||
return resp, self.clientauth.authenticated
|
||||
@@ -0,0 +1,595 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""LDAP SASL Authentication Plugin."""
|
||||
|
||||
import hmac
|
||||
|
||||
from base64 import b64decode, b64encode
|
||||
from hashlib import sha1, sha256
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from mysql.connector.authentication import ERR_STATUS
|
||||
from mysql.connector.errors import InterfaceError, ProgrammingError
|
||||
from mysql.connector.logger import logger
|
||||
from mysql.connector.types import StrOrBytes
|
||||
from mysql.connector.utils import (
|
||||
normalize_unicode_string as norm_ustr,
|
||||
validate_normalized_unicode_string as valid_norm,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
try:
|
||||
import gssapi
|
||||
except ImportError:
|
||||
raise ProgrammingError(
|
||||
"Module gssapi is required for GSSAPI authentication "
|
||||
"mechanism but was not found. Unable to authenticate "
|
||||
"with the server"
|
||||
) from None
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLLdapSaslPasswordAuthPlugin"
|
||||
|
||||
|
||||
# pylint: disable=c-extension-no-member,no-member
|
||||
class MySQLLdapSaslPasswordAuthPlugin(MySQLAuthPlugin):
|
||||
"""Class implementing the MySQL ldap sasl authentication plugin.
|
||||
|
||||
The MySQL's ldap sasl authentication plugin support two authentication
|
||||
methods SCRAM-SHA-1 and GSSAPI (using Kerberos). This implementation only
|
||||
support SCRAM-SHA-1 and SCRAM-SHA-256.
|
||||
|
||||
SCRAM-SHA-1 amd SCRAM-SHA-256
|
||||
This method requires 2 messages from client and 2 responses from
|
||||
server.
|
||||
|
||||
The first message from client will be generated by prepare_password(),
|
||||
after receive the response from the server, it is required that this
|
||||
response is passed back to auth_continue() which will return the
|
||||
second message from the client. After send this second message to the
|
||||
server, the second server respond needs to be passed to auth_finalize()
|
||||
to finish the authentication process.
|
||||
"""
|
||||
|
||||
sasl_mechanisms: List[str] = ["SCRAM-SHA-1", "SCRAM-SHA-256", "GSSAPI"]
|
||||
def_digest_mode: Callable = sha1
|
||||
client_nonce: Optional[str] = None
|
||||
client_salt: Any = None
|
||||
server_salt: Optional[str] = None
|
||||
krb_service_principal: Optional[str] = None
|
||||
iterations: int = 0
|
||||
server_auth_var: Optional[str] = None
|
||||
target_name: Optional[gssapi.Name] = None
|
||||
ctx: gssapi.SecurityContext = None
|
||||
servers_first: Optional[str] = None
|
||||
server_nonce: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def _xor(bytes1: bytes, bytes2: bytes) -> bytes:
|
||||
return bytes([b1 ^ b2 for b1, b2 in zip(bytes1, bytes2)])
|
||||
|
||||
def _hmac(self, password: bytes, salt: bytes) -> bytes:
|
||||
digest_maker = hmac.new(password, salt, self.def_digest_mode)
|
||||
return digest_maker.digest()
|
||||
|
||||
def _hi(self, password: str, salt: bytes, count: int) -> bytes:
|
||||
"""Prepares Hi
|
||||
Hi(password, salt, iterations) where Hi(p,s,i) is defined as
|
||||
PBKDF2 (HMAC, p, s, i, output length of H).
|
||||
"""
|
||||
pw = password.encode()
|
||||
hi = self._hmac(pw, salt + b"\x00\x00\x00\x01")
|
||||
aux = hi
|
||||
for _ in range(count - 1):
|
||||
aux = self._hmac(pw, aux)
|
||||
hi = self._xor(hi, aux)
|
||||
return hi
|
||||
|
||||
@staticmethod
|
||||
def _normalize(string: str) -> str:
|
||||
norm_str = norm_ustr(string)
|
||||
broken_rule = valid_norm(norm_str)
|
||||
if broken_rule is not None:
|
||||
raise InterfaceError(f"broken_rule: {broken_rule}")
|
||||
return norm_str
|
||||
|
||||
def _first_message(self) -> bytes:
|
||||
"""This method generates the first message to the server to start the
|
||||
|
||||
The client-first message consists of a gs2-header,
|
||||
the desired username, and a randomly generated client nonce cnonce.
|
||||
|
||||
The first message from the server has the form:
|
||||
b'n,a=<user_name>,n=<user_name>,r=<client_nonce>
|
||||
|
||||
Returns client's first message
|
||||
"""
|
||||
cfm_fprnat = "n,a={user_name},n={user_name},r={client_nonce}"
|
||||
self.client_nonce = str(uuid4()).replace("-", "")
|
||||
cfm: StrOrBytes = cfm_fprnat.format(
|
||||
user_name=self._normalize(self._username),
|
||||
client_nonce=self.client_nonce,
|
||||
)
|
||||
|
||||
if isinstance(cfm, str):
|
||||
cfm = cfm.encode("utf8")
|
||||
return cfm
|
||||
|
||||
def _first_message_krb(self) -> Optional[bytes]:
|
||||
"""Get a TGT Authentication request and initiates security context.
|
||||
|
||||
This method will contact the Kerberos KDC in order of obtain a TGT.
|
||||
"""
|
||||
user_name = gssapi.raw.names.import_name(
|
||||
self._username.encode("utf8"), name_type=gssapi.NameType.user
|
||||
)
|
||||
|
||||
# Use defaults store = {'ccache': 'FILE:/tmp/krb5cc_1000'}#,
|
||||
# 'keytab':'/etc/some.keytab' }
|
||||
# Attempt to retrieve credential from default cache file.
|
||||
try:
|
||||
cred: Any = gssapi.Credentials()
|
||||
logger.debug(
|
||||
"# Stored credentials found, if password was given it will be ignored."
|
||||
)
|
||||
try:
|
||||
# validate credentials has not expired.
|
||||
cred.lifetime
|
||||
except gssapi.raw.exceptions.ExpiredCredentialsError as err:
|
||||
logger.warning(" Credentials has expired: %s", err)
|
||||
cred.acquire(user_name)
|
||||
raise InterfaceError(f"Credentials has expired: {err}") from err
|
||||
except gssapi.raw.misc.GSSError as err:
|
||||
if not self._password:
|
||||
raise InterfaceError(
|
||||
f"Unable to retrieve stored credentials error: {err}"
|
||||
) from err
|
||||
try:
|
||||
logger.debug("# Attempt to retrieve credentials with given password")
|
||||
acquire_cred_result = gssapi.raw.acquire_cred_with_password(
|
||||
user_name,
|
||||
self._password.encode("utf8"),
|
||||
usage="initiate",
|
||||
)
|
||||
cred = acquire_cred_result[0]
|
||||
except gssapi.raw.misc.GSSError as err2:
|
||||
raise ProgrammingError(
|
||||
f"Unable to retrieve credentials with the given password: {err2}"
|
||||
) from err
|
||||
|
||||
flags_l = (
|
||||
gssapi.RequirementFlag.mutual_authentication,
|
||||
gssapi.RequirementFlag.extended_error,
|
||||
gssapi.RequirementFlag.delegate_to_peer,
|
||||
)
|
||||
|
||||
if self.krb_service_principal:
|
||||
service_principal = self.krb_service_principal
|
||||
else:
|
||||
service_principal = "ldap/ldapauth"
|
||||
logger.debug("# service principal: %s", service_principal)
|
||||
servk = gssapi.Name(
|
||||
service_principal, name_type=gssapi.NameType.kerberos_principal
|
||||
)
|
||||
self.target_name = servk
|
||||
self.ctx = gssapi.SecurityContext(
|
||||
name=servk, creds=cred, flags=sum(flags_l), usage="initiate"
|
||||
)
|
||||
|
||||
try:
|
||||
# step() returns bytes | None, see documentation,
|
||||
# so this method could return a NULL payload.
|
||||
# ref: https://pythongssapi.github.io/<suffix>
|
||||
# suffix: python-gssapi/latest/gssapi.html#gssapi.sec_contexts.SecurityContext
|
||||
initial_client_token = self.ctx.step()
|
||||
except gssapi.raw.misc.GSSError as err:
|
||||
raise InterfaceError(f"Unable to initiate security context: {err}") from err
|
||||
|
||||
logger.debug("# initial client token: %s", initial_client_token)
|
||||
return initial_client_token
|
||||
|
||||
def auth_continue_krb(
|
||||
self, tgt_auth_challenge: Optional[bytes]
|
||||
) -> Tuple[Optional[bytes], bool]:
|
||||
"""Continue with the Kerberos TGT service request.
|
||||
|
||||
With the TGT authentication service given response generate a TGT
|
||||
service request. This method must be invoked sequentially (in a loop)
|
||||
until the security context is completed and an empty response needs to
|
||||
be send to acknowledge the server.
|
||||
|
||||
Args:
|
||||
tgt_auth_challenge the challenge for the negotiation.
|
||||
|
||||
Returns: tuple (bytearray TGS service request,
|
||||
bool True if context is completed otherwise False).
|
||||
"""
|
||||
logger.debug("tgt_auth challenge: %s", tgt_auth_challenge)
|
||||
|
||||
resp = self.ctx.step(tgt_auth_challenge)
|
||||
logger.debug("# context step response: %s", resp)
|
||||
logger.debug("# context completed?: %s", self.ctx.complete)
|
||||
|
||||
return resp, self.ctx.complete
|
||||
|
||||
def auth_accept_close_handshake(self, message: bytes) -> bytes:
|
||||
"""Accept handshake and generate closing handshake message for server.
|
||||
|
||||
This method verifies the server authenticity from the given message
|
||||
and included signature and generates the closing handshake for the
|
||||
server.
|
||||
|
||||
When this method is invoked the security context is already established
|
||||
and the client and server can send GSSAPI formated secure messages.
|
||||
|
||||
To finish the authentication handshake the server sends a message
|
||||
with the security layer availability and the maximum buffer size.
|
||||
|
||||
Since the connector only uses the GSSAPI authentication mechanism to
|
||||
authenticate the user with the server, the server will verify clients
|
||||
message signature and terminate the GSSAPI authentication and send two
|
||||
messages; an authentication acceptance b'\x01\x00\x00\x08\x01' and a
|
||||
OK packet (that must be received after sent the returned message from
|
||||
this method).
|
||||
|
||||
Args:
|
||||
message a wrapped hssapi message from the server.
|
||||
|
||||
Returns: bytearray closing handshake message to be send to the server.
|
||||
"""
|
||||
if not self.ctx.complete:
|
||||
raise ProgrammingError("Security context is not completed.")
|
||||
logger.debug("# servers message: %s", message)
|
||||
logger.debug("# GSSAPI flags in use: %s", self.ctx.actual_flags)
|
||||
try:
|
||||
unwraped = self.ctx.unwrap(message)
|
||||
logger.debug("# unwraped: %s", unwraped)
|
||||
except gssapi.raw.exceptions.BadMICError as err:
|
||||
raise InterfaceError(f"Unable to unwrap server message: {err}") from err
|
||||
|
||||
logger.debug("# unwrapped server message: %s", unwraped)
|
||||
# The message contents for the clients closing message:
|
||||
# - security level 1 byte, must be always 1.
|
||||
# - conciliated buffer size 3 bytes, without importance as no
|
||||
# further GSSAPI messages will be sends.
|
||||
response = bytearray(b"\x01\x00\x00\00")
|
||||
# Closing handshake must not be encrypted.
|
||||
logger.debug("# message response: %s", response)
|
||||
wraped = self.ctx.wrap(response, encrypt=False)
|
||||
logger.debug(
|
||||
"# wrapped message response: %s, length: %d",
|
||||
wraped[0],
|
||||
len(wraped[0]),
|
||||
)
|
||||
|
||||
return wraped.message
|
||||
|
||||
def auth_response(
|
||||
self,
|
||||
auth_data: bytes,
|
||||
**kwargs: Any,
|
||||
) -> Optional[bytes]:
|
||||
"""This method will prepare the fist message to the server.
|
||||
|
||||
Returns bytes to send to the server as the first message.
|
||||
"""
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self._auth_data = auth_data
|
||||
|
||||
auth_mechanism = self._auth_data.decode()
|
||||
logger.debug("read_method_name_from_server: %s", auth_mechanism)
|
||||
if auth_mechanism not in self.sasl_mechanisms:
|
||||
auth_mechanisms = '", "'.join(self.sasl_mechanisms[:-1])
|
||||
raise InterfaceError(
|
||||
f'The sasl authentication method "{auth_mechanism}" requested '
|
||||
f'from the server is not supported. Only "{auth_mechanisms}" '
|
||||
f'and "{self.sasl_mechanisms[-1]}" are supported'
|
||||
)
|
||||
|
||||
if b"GSSAPI" in self._auth_data:
|
||||
return self._first_message_krb()
|
||||
|
||||
if self._auth_data == b"SCRAM-SHA-256":
|
||||
self.def_digest_mode = sha256
|
||||
|
||||
return self._first_message()
|
||||
|
||||
def _second_message(self) -> bytes:
|
||||
"""This method generates the second message to the server
|
||||
|
||||
Second message consist on the concatenation of the client and the
|
||||
server nonce, and cproof.
|
||||
|
||||
c=<n,a=<user_name>>,r=<server_nonce>,p=<client_proof>
|
||||
where:
|
||||
<client_proof>: xor(<client_key>, <client_signature>)
|
||||
|
||||
<client_key>: hmac(salted_password, b"Client Key")
|
||||
<client_signature>: hmac(<stored_key>, <auth_msg>)
|
||||
<stored_key>: h(<client_key>)
|
||||
<auth_msg>: <client_first_no_header>,<servers_first>,
|
||||
c=<client_header>,r=<server_nonce>
|
||||
<client_first_no_header>: n=<username>r=<client_nonce>
|
||||
"""
|
||||
if not self._auth_data:
|
||||
raise InterfaceError("Missing authentication data (seed)")
|
||||
|
||||
passw = self._normalize(self._password)
|
||||
salted_password = self._hi(passw, b64decode(self.server_salt), self.iterations)
|
||||
logger.debug("salted_password: %s", b64encode(salted_password).decode())
|
||||
|
||||
client_key = self._hmac(salted_password, b"Client Key")
|
||||
logger.debug("client_key: %s", b64encode(client_key).decode())
|
||||
|
||||
stored_key = self.def_digest_mode(client_key).digest()
|
||||
logger.debug("stored_key: %s", b64encode(stored_key).decode())
|
||||
|
||||
server_key = self._hmac(salted_password, b"Server Key")
|
||||
logger.debug("server_key: %s", b64encode(server_key).decode())
|
||||
|
||||
client_first_no_header = ",".join(
|
||||
[
|
||||
f"n={self._normalize(self._username)}",
|
||||
f"r={self.client_nonce}",
|
||||
]
|
||||
)
|
||||
logger.debug("client_first_no_header: %s", client_first_no_header)
|
||||
|
||||
client_header = b64encode(
|
||||
f"n,a={self._normalize(self._username)},".encode()
|
||||
).decode()
|
||||
|
||||
auth_msg = ",".join(
|
||||
[
|
||||
client_first_no_header,
|
||||
self.servers_first,
|
||||
f"c={client_header}",
|
||||
f"r={self.server_nonce}",
|
||||
]
|
||||
)
|
||||
logger.debug("auth_msg: %s", auth_msg)
|
||||
|
||||
client_signature = self._hmac(stored_key, auth_msg.encode())
|
||||
logger.debug("client_signature: %s", b64encode(client_signature).decode())
|
||||
|
||||
client_proof = self._xor(client_key, client_signature)
|
||||
logger.debug("client_proof: %s", b64encode(client_proof).decode())
|
||||
|
||||
self.server_auth_var = b64encode(
|
||||
self._hmac(server_key, auth_msg.encode())
|
||||
).decode()
|
||||
logger.debug("server_auth_var: %s", self.server_auth_var)
|
||||
|
||||
msg = ",".join(
|
||||
[
|
||||
f"c={client_header}",
|
||||
f"r={self.server_nonce}",
|
||||
f"p={b64encode(client_proof).decode()}",
|
||||
]
|
||||
)
|
||||
logger.debug("second_message: %s", msg)
|
||||
return msg.encode()
|
||||
|
||||
def _validate_first_reponse(self, servers_first: bytes) -> None:
|
||||
"""Validates first message from the server.
|
||||
|
||||
Extracts the server's salt and iterations from the servers 1st response.
|
||||
First message from the server is in the form:
|
||||
<server_salt>,i=<iterations>
|
||||
"""
|
||||
if not servers_first or not isinstance(servers_first, (bytearray, bytes)):
|
||||
raise InterfaceError(f"Unexpected server message: {repr(servers_first)}")
|
||||
try:
|
||||
servers_first_str = servers_first.decode()
|
||||
self.servers_first = servers_first_str
|
||||
r_server_nonce, s_salt, i_counter = servers_first_str.split(",")
|
||||
except ValueError:
|
||||
raise InterfaceError(
|
||||
f"Unexpected server message: {servers_first_str}"
|
||||
) from None
|
||||
if (
|
||||
not r_server_nonce.startswith("r=")
|
||||
or not s_salt.startswith("s=")
|
||||
or not i_counter.startswith("i=")
|
||||
):
|
||||
raise InterfaceError(
|
||||
f"Incomplete reponse from the server: {servers_first_str}"
|
||||
)
|
||||
if self.client_nonce in r_server_nonce:
|
||||
self.server_nonce = r_server_nonce[2:]
|
||||
logger.debug("server_nonce: %s", self.server_nonce)
|
||||
else:
|
||||
raise InterfaceError(
|
||||
"Unable to authenticate response: response not well formed "
|
||||
f"{servers_first_str}"
|
||||
)
|
||||
self.server_salt = s_salt[2:]
|
||||
logger.debug(
|
||||
"server_salt: %s length: %s",
|
||||
self.server_salt,
|
||||
len(self.server_salt),
|
||||
)
|
||||
try:
|
||||
i_counter = i_counter[2:]
|
||||
logger.debug("iterations: %s", i_counter)
|
||||
self.iterations = int(i_counter)
|
||||
except Exception as err:
|
||||
raise InterfaceError(
|
||||
f"Unable to authenticate: iterations not found {servers_first_str}"
|
||||
) from err
|
||||
|
||||
def auth_continue(self, servers_first_response: bytes) -> bytes:
|
||||
"""return the second message from the client.
|
||||
|
||||
Returns bytes to send to the server as the second message.
|
||||
"""
|
||||
self._validate_first_reponse(servers_first_response)
|
||||
return self._second_message()
|
||||
|
||||
def _validate_second_reponse(self, servers_second: bytearray) -> bool:
|
||||
"""Validates second message from the server.
|
||||
|
||||
The client and the server prove to each other they have the same Auth
|
||||
variable.
|
||||
|
||||
The second message from the server consist of the server's proof:
|
||||
server_proof = HMAC(<server_key>, <auth_msg>)
|
||||
where:
|
||||
<server_key>: hmac(<salted_password>, b"Server Key")
|
||||
<auth_msg>: <client_first_no_header>,<servers_first>,
|
||||
c=<client_header>,r=<server_nonce>
|
||||
|
||||
Our server_proof must be equal to the Auth variable send on this second
|
||||
response.
|
||||
"""
|
||||
if (
|
||||
not servers_second
|
||||
or not isinstance(servers_second, bytearray)
|
||||
or len(servers_second) <= 2
|
||||
or not servers_second.startswith(b"v=")
|
||||
):
|
||||
raise InterfaceError("The server's proof is not well formated")
|
||||
server_var = servers_second[2:].decode()
|
||||
logger.debug("server auth variable: %s", server_var)
|
||||
return self.server_auth_var == server_var
|
||||
|
||||
def auth_finalize(self, servers_second_response: bytearray) -> bool:
|
||||
"""finalize the authentication process.
|
||||
|
||||
Raises InterfaceError if the ervers_second_response is invalid.
|
||||
|
||||
Returns True in successful authentication False otherwise.
|
||||
"""
|
||||
if not self._validate_second_reponse(servers_second_response):
|
||||
raise InterfaceError(
|
||||
"Authentication failed: Unable to proof server identity"
|
||||
)
|
||||
return True
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "authentication_ldap_sasl_client"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return False
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
logger.debug("# auth_data: %s", auth_data)
|
||||
self.krb_service_principal = kwargs.get("krb_service_principal")
|
||||
|
||||
response = self.auth_response(auth_data, **kwargs)
|
||||
if response is None:
|
||||
raise InterfaceError("Got a NULL auth response")
|
||||
|
||||
logger.debug("# request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
packet = await sock.read()
|
||||
logger.debug("# server response packet: %s", packet)
|
||||
|
||||
if len(packet) >= 6 and packet[5] == 114 and packet[6] == 61: # 'r' and '='
|
||||
# Continue with sasl authentication
|
||||
dec_response = packet[5:]
|
||||
cresponse = self.auth_continue(dec_response)
|
||||
await sock.write(cresponse)
|
||||
packet = await sock.read()
|
||||
if packet[5] == 118 and packet[6] == 61: # 'v' and '='
|
||||
if self.auth_finalize(packet[5:]):
|
||||
# receive packed OK
|
||||
packet = await sock.read()
|
||||
elif auth_data == b"GSSAPI" and packet[4] != ERR_STATUS:
|
||||
rcode_size = 5 # header size for the response status code.
|
||||
logger.debug("# Continue with sasl GSSAPI authentication")
|
||||
logger.debug("# response header: %s", packet[: rcode_size + 1])
|
||||
logger.debug("# response size: %s", len(packet))
|
||||
|
||||
logger.debug("# Negotiate a service request")
|
||||
complete = False
|
||||
tries = 0 # To avoid a infinite loop attempt no more than feedback messages
|
||||
while not complete and tries < 5:
|
||||
logger.debug("%s Attempt %s %s", "-" * 20, tries + 1, "-" * 20)
|
||||
logger.debug("<< server response: %s", packet)
|
||||
logger.debug("# response code: %s", packet[: rcode_size + 1])
|
||||
step, complete = self.auth_continue_krb(packet[rcode_size:])
|
||||
logger.debug(" >> response to server: %s", step)
|
||||
await sock.write(step or b"")
|
||||
packet = await sock.read()
|
||||
tries += 1
|
||||
if not complete:
|
||||
raise InterfaceError(
|
||||
f"Unable to fulfill server request after {tries} "
|
||||
f"attempts. Last server response: {packet}"
|
||||
)
|
||||
logger.debug(
|
||||
" last GSSAPI response from server: %s length: %d",
|
||||
packet,
|
||||
len(packet),
|
||||
)
|
||||
last_step = self.auth_accept_close_handshake(packet[rcode_size:])
|
||||
logger.debug(
|
||||
" >> last response to server: %s length: %d",
|
||||
last_step,
|
||||
len(last_step),
|
||||
)
|
||||
await sock.write(last_step)
|
||||
# Receive final handshake from server
|
||||
packet = await sock.read()
|
||||
logger.debug("<< final handshake from server: %s", packet)
|
||||
|
||||
# receive OK packet from server.
|
||||
packet = await sock.read()
|
||||
logger.debug("<< ok packet from server: %s", packet)
|
||||
|
||||
return bytes(packet)
|
||||
|
||||
|
||||
# pylint: enable=c-extension-no-member,no-member
|
||||
@@ -0,0 +1,234 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
# mypy: disable-error-code="arg-type,union-attr,call-arg"
|
||||
|
||||
"""OCI Authentication Plugin."""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
from base64 import b64encode
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from mysql.connector import errors
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
try:
|
||||
from cryptography.exceptions import UnsupportedAlgorithm
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES
|
||||
except ImportError:
|
||||
raise errors.ProgrammingError("Package 'cryptography' is not installed") from None
|
||||
|
||||
try:
|
||||
from oci import config, exceptions
|
||||
except ImportError:
|
||||
raise errors.ProgrammingError(
|
||||
"Package 'oci' (Oracle Cloud Infrastructure Python SDK) is not installed"
|
||||
) from None
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLOCIAuthPlugin"
|
||||
OCI_SECURITY_TOKEN_MAX_SIZE = 10 * 1024 # In bytes
|
||||
OCI_SECURITY_TOKEN_TOO_LARGE = "Ephemeral security token is too large (10KB max)"
|
||||
OCI_SECURITY_TOKEN_FILE_NOT_AVAILABLE = (
|
||||
"Ephemeral security token file ('security_token_file') could not be read"
|
||||
)
|
||||
OCI_PROFILE_MISSING_PROPERTIES = (
|
||||
"OCI configuration file does not contain a 'fingerprint' or 'key_file' entry"
|
||||
)
|
||||
|
||||
|
||||
class MySQLOCIAuthPlugin(MySQLAuthPlugin):
|
||||
"""Implement the MySQL OCI IAM authentication plugin."""
|
||||
|
||||
context: Any = None
|
||||
oci_config_profile: str = "DEFAULT"
|
||||
oci_config_file: str = config.DEFAULT_LOCATION
|
||||
|
||||
@staticmethod
|
||||
def _prepare_auth_response(signature: bytes, oci_config: Dict[str, Any]) -> str:
|
||||
"""Prepare client's authentication response
|
||||
|
||||
Prepares client's authentication response in JSON format
|
||||
Args:
|
||||
signature (bytes): server's nonce to be signed by client.
|
||||
oci_config (dict): OCI configuration object.
|
||||
|
||||
Returns:
|
||||
str: JSON string with the following format:
|
||||
{"fingerprint": str, "signature": str, "token": base64.base64.base64}
|
||||
|
||||
Raises:
|
||||
ProgrammingError: If the ephemeral security token file can't be open or the
|
||||
token is too large.
|
||||
"""
|
||||
signature_64 = b64encode(signature)
|
||||
auth_response = {
|
||||
"fingerprint": oci_config["fingerprint"],
|
||||
"signature": signature_64.decode(),
|
||||
}
|
||||
|
||||
# The security token, if it exists, should be a JWT (JSON Web Token), consisted
|
||||
# of a base64-encoded header, body, and signature, separated by '.',
|
||||
# e.g. "Base64.Base64.Base64", stored in a file at the path specified by the
|
||||
# security_token_file configuration property
|
||||
if oci_config.get("security_token_file"):
|
||||
try:
|
||||
security_token_file = Path(oci_config["security_token_file"])
|
||||
# Check if token exceeds the maximum size
|
||||
if security_token_file.stat().st_size > OCI_SECURITY_TOKEN_MAX_SIZE:
|
||||
raise errors.ProgrammingError(OCI_SECURITY_TOKEN_TOO_LARGE)
|
||||
auth_response["token"] = security_token_file.read_text(encoding="utf-8")
|
||||
except (OSError, UnicodeError) as err:
|
||||
raise errors.ProgrammingError(
|
||||
OCI_SECURITY_TOKEN_FILE_NOT_AVAILABLE
|
||||
) from err
|
||||
return json.dumps(auth_response, separators=(",", ":"))
|
||||
|
||||
@staticmethod
|
||||
def _get_private_key(key_path: str) -> PRIVATE_KEY_TYPES:
|
||||
"""Get the private_key form the given location"""
|
||||
try:
|
||||
with open(os.path.expanduser(key_path), "rb") as key_file:
|
||||
private_key = serialization.load_pem_private_key(
|
||||
key_file.read(),
|
||||
password=None,
|
||||
)
|
||||
except (TypeError, OSError, ValueError, UnsupportedAlgorithm) as err:
|
||||
raise errors.ProgrammingError(
|
||||
"An error occurred while reading the API_KEY from "
|
||||
f'"{key_path}": {err}'
|
||||
)
|
||||
|
||||
return private_key
|
||||
|
||||
def _get_valid_oci_config(self) -> Dict[str, Any]:
|
||||
"""Get a valid OCI config from the given configuration file path"""
|
||||
error_list = []
|
||||
req_keys = {
|
||||
"fingerprint": (lambda x: len(x) > 32),
|
||||
"key_file": (lambda x: os.path.exists(os.path.expanduser(x))),
|
||||
}
|
||||
|
||||
oci_config: Dict[str, Any] = {}
|
||||
try:
|
||||
# key_file is validated by oci.config if present
|
||||
oci_config = config.from_file(
|
||||
self.oci_config_file or config.DEFAULT_LOCATION,
|
||||
self.oci_config_profile or "DEFAULT",
|
||||
)
|
||||
for req_key, req_value in req_keys.items():
|
||||
try:
|
||||
# Verify parameter in req_key is present and valid
|
||||
if oci_config[req_key] and not req_value(oci_config[req_key]):
|
||||
error_list.append(f'Parameter "{req_key}" is invalid')
|
||||
except KeyError:
|
||||
error_list.append(f"Does not contain parameter {req_key}")
|
||||
except (
|
||||
exceptions.ConfigFileNotFound,
|
||||
exceptions.InvalidConfig,
|
||||
exceptions.InvalidKeyFilePath,
|
||||
exceptions.InvalidPrivateKey,
|
||||
exceptions.ProfileNotFound,
|
||||
) as err:
|
||||
error_list.append(str(err))
|
||||
|
||||
# Raise errors if any
|
||||
if error_list:
|
||||
raise errors.ProgrammingError(
|
||||
f"Invalid oci-config-file: {self.oci_config_file}. "
|
||||
f"Errors found: {error_list}"
|
||||
)
|
||||
|
||||
return oci_config
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "authentication_oci_client"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return False
|
||||
|
||||
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
|
||||
"""Prepare authentication string for the server."""
|
||||
logger.debug("server nonce: %s, len %d", auth_data, len(auth_data))
|
||||
|
||||
oci_config = self._get_valid_oci_config()
|
||||
|
||||
private_key = self._get_private_key(oci_config["key_file"])
|
||||
signature = private_key.sign(auth_data, padding.PKCS1v15(), hashes.SHA256())
|
||||
|
||||
auth_response = self._prepare_auth_response(signature, oci_config)
|
||||
logger.debug("authentication response: %s", auth_response)
|
||||
return auth_response.encode()
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
self.oci_config_file = kwargs.get("oci_config_file", "DEFAULT")
|
||||
self.oci_config_profile = kwargs.get(
|
||||
"oci_config_profile", config.DEFAULT_LOCATION
|
||||
)
|
||||
logger.debug("# oci configuration file path: %s", self.oci_config_file)
|
||||
|
||||
response = self.auth_response(auth_data, **kwargs)
|
||||
if response is None:
|
||||
raise errors.InterfaceError("Got a NULL auth response")
|
||||
|
||||
logger.debug("# request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
packet = await sock.read()
|
||||
logger.debug("# server response packet: %s", packet)
|
||||
|
||||
return bytes(packet)
|
||||
@@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""OpenID Authentication Plugin."""
|
||||
|
||||
import re
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, List
|
||||
|
||||
from mysql.connector import errors, utils
|
||||
from mysql.connector.aio.network import MySQLSocket
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLOpenIDConnectAuthPlugin"
|
||||
OPENID_TOKEN_MAX_SIZE = 10 * 1024 # In bytes
|
||||
|
||||
|
||||
class MySQLOpenIDConnectAuthPlugin(MySQLAuthPlugin):
|
||||
"""Class implementing the MySQL OpenID Connect Authentication Plugin."""
|
||||
|
||||
_openid_capability_flag: bytes = utils.int1store(1)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "authentication_openid_connect_client"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _validate_openid_token(token: str) -> bool:
|
||||
"""Helper method used to validate OpenID Connect token
|
||||
|
||||
The Token is represented as a JSON Web Token (JWT) consists of a
|
||||
base64-encoded header, body, and signature, separated by '.' e.g.,
|
||||
"Base64url.Base64url.Base64url". The First part of the token contains
|
||||
the header, the second part contains payload and the third part contains
|
||||
signature. These token parts should be Base64 URLSafe i.e., Token cannot
|
||||
contain characters other than a-z, A-Z, 0-9 and special characters '-', '_'.
|
||||
|
||||
Args:
|
||||
token (str): Base64url-encoded OpenID connect token fetched from
|
||||
the file path passed via `openid_token_file` connection
|
||||
argument.
|
||||
|
||||
Returns:
|
||||
bool: Signal indicating whether the token is valid or not.
|
||||
"""
|
||||
header_payload_sig: List[str] = token.split(".")
|
||||
if len(header_payload_sig) != 3:
|
||||
# invalid structure
|
||||
return False
|
||||
urlsafe_pattern = re.compile("^[a-zA-Z0-9-_]*$")
|
||||
return all(
|
||||
(
|
||||
len(token_part) and urlsafe_pattern.search(token_part) is not None
|
||||
for token_part in header_payload_sig
|
||||
)
|
||||
)
|
||||
|
||||
def auth_response(self, auth_data: bytes, **kwargs: Any) -> bytes:
|
||||
"""Prepares authentication string for the server.
|
||||
Args:
|
||||
auth_data: Authorization data.
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked.
|
||||
|
||||
Returns:
|
||||
packet: Client's authorization response.
|
||||
The OpenID Connect authorization response follows the pattern :-
|
||||
int<1> capability flag
|
||||
string<lenenc> id token
|
||||
|
||||
Raises:
|
||||
InterfaceError: If the connection is insecure or the OpenID Token is too large,
|
||||
invalid or non-existent.
|
||||
ProgrammingError: If the OpenID Token file could not be read.
|
||||
"""
|
||||
try:
|
||||
# Check if the connection is secure
|
||||
if self.requires_ssl and not self._ssl_enabled:
|
||||
raise errors.InterfaceError(f"{self.name} requires SSL")
|
||||
|
||||
# Validate the file
|
||||
token_file_path: str = kwargs.get("openid_token_file", None)
|
||||
openid_token_file: Path = Path(token_file_path)
|
||||
# Check if token exceeds the maximum size
|
||||
if openid_token_file.stat().st_size > OPENID_TOKEN_MAX_SIZE:
|
||||
raise errors.InterfaceError(
|
||||
"The OpenID Connect token file size is too large (> 10KB)"
|
||||
)
|
||||
openid_token: str = openid_token_file.read_text(encoding="utf-8")
|
||||
openid_token = openid_token.strip()
|
||||
# Validate the JWT Token
|
||||
if not self._validate_openid_token(openid_token):
|
||||
raise errors.InterfaceError("The OpenID Connect Token is invalid")
|
||||
|
||||
# build the auth_response packet
|
||||
auth_response: List[bytes] = [
|
||||
self._openid_capability_flag,
|
||||
utils.lc_int(len(openid_token)),
|
||||
openid_token.encode(),
|
||||
]
|
||||
return b"".join(auth_response)
|
||||
except (SyntaxError, TypeError, OSError, UnicodeError) as err:
|
||||
raise errors.ProgrammingError(
|
||||
"The OpenID Connect Token File (openid_token_file) could not be read"
|
||||
) from err
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: MySQLSocket, auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
|
||||
Raises:
|
||||
InterfaceError: If a NULL auth response is received from auth_response method.
|
||||
"""
|
||||
response = self.auth_response(auth_data, **kwargs)
|
||||
|
||||
if response is None:
|
||||
raise errors.InterfaceError("Got a NULL auth response")
|
||||
|
||||
logger.debug("# request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
packet = await sock.read()
|
||||
logger.debug("# server response packet: %s", packet)
|
||||
|
||||
return bytes(packet)
|
||||
@@ -0,0 +1,291 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""WebAuthn Authentication Plugin."""
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
from mysql.connector import errors, utils
|
||||
|
||||
from ..logger import logger
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
try:
|
||||
from fido2.cbor import dump_bytes as cbor_dump_bytes
|
||||
from fido2.client import Fido2Client, UserInteraction
|
||||
from fido2.hid import CtapHidDevice
|
||||
from fido2.webauthn import PublicKeyCredentialRequestOptions
|
||||
except ImportError as import_err:
|
||||
raise errors.ProgrammingError(
|
||||
"Module fido2 is required for WebAuthn authentication mechanism but was "
|
||||
"not found. Unable to authenticate with the server"
|
||||
) from import_err
|
||||
|
||||
try:
|
||||
from fido2.pcsc import CtapPcscDevice
|
||||
|
||||
CTAP_PCSC_DEVICE_AVAILABLE = True
|
||||
except ModuleNotFoundError:
|
||||
CTAP_PCSC_DEVICE_AVAILABLE = False
|
||||
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLWebAuthnAuthPlugin"
|
||||
|
||||
|
||||
class ClientInteraction(UserInteraction):
|
||||
"""Provides user interaction to the Client."""
|
||||
|
||||
def __init__(self, callback: Optional[Callable] = None):
|
||||
self.callback = callback
|
||||
self.msg = (
|
||||
"Please insert FIDO device and perform gesture action for authentication "
|
||||
"to complete."
|
||||
)
|
||||
|
||||
def prompt_up(self) -> None:
|
||||
"""Prompt message for the user interaction with the FIDO device."""
|
||||
if self.callback is None:
|
||||
print(self.msg)
|
||||
else:
|
||||
self.callback(self.msg)
|
||||
|
||||
|
||||
class MySQLWebAuthnAuthPlugin(MySQLAuthPlugin):
|
||||
"""Class implementing the MySQL WebAuthn authentication plugin."""
|
||||
|
||||
client: Optional[Fido2Client] = None
|
||||
callback: Optional[Callable] = None
|
||||
options: dict = {"rpId": None, "challenge": None, "allowCredentials": []}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "authentication_webauthn_client"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return False
|
||||
|
||||
def get_assertion_response(
|
||||
self, credential_id: Optional[bytearray] = None
|
||||
) -> bytes:
|
||||
"""Get assertion from authenticator and return the response.
|
||||
|
||||
Args:
|
||||
credential_id (Optional[bytearray]): The credential ID.
|
||||
|
||||
Returns:
|
||||
bytearray: The response packet with the data from the assertion.
|
||||
"""
|
||||
if self.client is None:
|
||||
raise errors.InterfaceError("No WebAuthn client found")
|
||||
|
||||
if credential_id is not None:
|
||||
# If credential_id is not None, it's because the FIDO device does not
|
||||
# support resident keys and the credential_id was requested from the server
|
||||
self.options["allowCredentials"] = [
|
||||
{
|
||||
"id": credential_id,
|
||||
"type": "public-key",
|
||||
}
|
||||
]
|
||||
|
||||
# Get assertion from authenticator
|
||||
assertion = self.client.get_assertion(
|
||||
PublicKeyCredentialRequestOptions.from_dict(self.options)
|
||||
)
|
||||
number_of_assertions = len(assertion.get_assertions())
|
||||
client_data_json = b""
|
||||
|
||||
# Build response packet
|
||||
#
|
||||
# Format:
|
||||
# int<1> 0x02 (2) status tag
|
||||
# int<lenenc> number of assertions length encoded number of assertions
|
||||
# string authenticator data variable length raw binary string
|
||||
# string signed challenge variable length raw binary string
|
||||
# ...
|
||||
# ...
|
||||
# string authenticator data variable length raw binary string
|
||||
# string signed challenge variable length raw binary string
|
||||
# string ClientDataJSON variable length raw binary string
|
||||
packet = utils.lc_int(2)
|
||||
packet += utils.lc_int(number_of_assertions)
|
||||
|
||||
# Add authenticator data and signed challenge for each assertion
|
||||
for i in range(number_of_assertions):
|
||||
assertion_response = assertion.get_response(i)
|
||||
|
||||
# string<lenenc> authenticator_data
|
||||
authenticator_data = cbor_dump_bytes(assertion_response.authenticator_data)
|
||||
|
||||
# string<lenenc> signed_challenge
|
||||
signature = assertion_response.signature
|
||||
|
||||
packet += utils.lc_int(len(authenticator_data))
|
||||
packet += authenticator_data
|
||||
packet += utils.lc_int(len(signature))
|
||||
packet += signature
|
||||
|
||||
# string<lenenc> client_data_json
|
||||
client_data_json = assertion_response.client_data
|
||||
|
||||
packet += utils.lc_int(len(client_data_json))
|
||||
packet += client_data_json
|
||||
|
||||
logger.debug("WebAuthn - payload response packet: %s", packet)
|
||||
return packet
|
||||
|
||||
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
|
||||
"""Find authenticator device and check if supports resident keys.
|
||||
|
||||
It also creates a Fido2Client using the relying party ID from the server.
|
||||
|
||||
Raises:
|
||||
InterfaceError: When the FIDO device is not found.
|
||||
|
||||
Returns:
|
||||
bytes: 2 if the authenticator supports resident keys else 1.
|
||||
"""
|
||||
try:
|
||||
packets, capability = utils.read_int(auth_data, 1)
|
||||
challenge, rp_id = utils.read_lc_string_list(packets)
|
||||
self.options["challenge"] = challenge
|
||||
self.options["rpId"] = rp_id.decode()
|
||||
logger.debug("WebAuthn - capability: %d", capability)
|
||||
logger.debug("WebAuthn - challenge: %s", self.options["challenge"])
|
||||
logger.debug("WebAuthn - relying party id: %s", self.options["rpId"])
|
||||
except ValueError as err:
|
||||
raise errors.InterfaceError(
|
||||
"Unable to parse MySQL WebAuthn authentication data"
|
||||
) from err
|
||||
|
||||
# Locate a device
|
||||
device = next(CtapHidDevice.list_devices(), None)
|
||||
if device is not None:
|
||||
logger.debug("WebAuthn - Use USB HID channel")
|
||||
elif CTAP_PCSC_DEVICE_AVAILABLE:
|
||||
device = next(CtapPcscDevice.list_devices(), None)
|
||||
|
||||
if device is None:
|
||||
raise errors.InterfaceError("No FIDO device found")
|
||||
|
||||
# Set up a FIDO 2 client using the origin relying party id
|
||||
self.client = Fido2Client(
|
||||
device,
|
||||
f"https://{self.options['rpId']}",
|
||||
user_interaction=ClientInteraction(self.callback),
|
||||
)
|
||||
|
||||
if not self.client.info.options.get("rk"):
|
||||
logger.debug("WebAuthn - Authenticator doesn't support resident keys")
|
||||
return b"1"
|
||||
|
||||
logger.debug("WebAuthn - Authenticator with support for resident key found")
|
||||
return b"2"
|
||||
|
||||
async def auth_more_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth more data` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Authentication method data (from a packet representing
|
||||
an `auth more data` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
_, credential_id = utils.read_lc_string(auth_data)
|
||||
|
||||
response = self.get_assertion_response(credential_id)
|
||||
|
||||
logger.debug("WebAuthn - request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
pkt = bytes(await sock.read())
|
||||
logger.debug("WebAuthn - server response packet: %s", pkt)
|
||||
|
||||
return pkt
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
webauth_callback = kwargs.get("webauthn_callback") or kwargs.get(
|
||||
"fido_callback"
|
||||
)
|
||||
self.callback = (
|
||||
utils.import_object(webauth_callback)
|
||||
if isinstance(webauth_callback, str)
|
||||
else webauth_callback
|
||||
)
|
||||
|
||||
response = self.auth_response(auth_data)
|
||||
credential_id = None
|
||||
|
||||
if response == b"1":
|
||||
# Authenticator doesn't support resident keys, request credential_id
|
||||
logger.debug("WebAuthn - request credential_id")
|
||||
await sock.write(utils.lc_int(int(response)))
|
||||
|
||||
# return a packet representing an `auth more data` response
|
||||
return bytes(await sock.read())
|
||||
|
||||
response = self.get_assertion_response(credential_id)
|
||||
|
||||
logger.debug("WebAuthn - request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
pkt = bytes(await sock.read())
|
||||
logger.debug("WebAuthn - server response packet: %s", pkt)
|
||||
|
||||
return pkt
|
||||
@@ -0,0 +1,160 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Caching SHA2 Password Authentication Plugin."""
|
||||
|
||||
import struct
|
||||
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from mysql.connector.errors import InterfaceError
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLCachingSHA2PasswordAuthPlugin"
|
||||
|
||||
|
||||
class MySQLCachingSHA2PasswordAuthPlugin(MySQLAuthPlugin):
|
||||
"""Class implementing the MySQL caching_sha2_password authentication plugin
|
||||
|
||||
Note that encrypting using RSA is not supported since the Python
|
||||
Standard Library does not provide this OpenSSL functionality.
|
||||
"""
|
||||
|
||||
perform_full_authentication: int = 4
|
||||
|
||||
def _scramble(self, auth_data: bytes) -> bytes:
|
||||
"""Return a scramble of the password using a Nonce sent by the
|
||||
server.
|
||||
|
||||
The scramble is of the form:
|
||||
XOR(SHA2(password), SHA2(SHA2(SHA2(password)), Nonce))
|
||||
"""
|
||||
if not auth_data:
|
||||
raise InterfaceError("Missing authentication data (seed)")
|
||||
|
||||
if not self._password:
|
||||
return b""
|
||||
|
||||
hash1 = sha256(self._password.encode()).digest()
|
||||
hash2 = sha256()
|
||||
hash2.update(sha256(hash1).digest())
|
||||
hash2.update(auth_data)
|
||||
hash2_digest = hash2.digest()
|
||||
xored = [h1 ^ h2 for (h1, h2) in zip(hash1, hash2_digest)]
|
||||
hash3 = struct.pack("32B", *xored)
|
||||
return hash3
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "caching_sha2_password"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return False
|
||||
|
||||
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
|
||||
"""Make the client's authorization response.
|
||||
|
||||
Args:
|
||||
auth_data: Authorization data.
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Client's authorization response.
|
||||
"""
|
||||
if not auth_data:
|
||||
return None
|
||||
if len(auth_data) > 1:
|
||||
return self._scramble(auth_data)
|
||||
if auth_data[0] == self.perform_full_authentication:
|
||||
# return password as clear text.
|
||||
return self._password.encode() + b"\x00"
|
||||
|
||||
return None
|
||||
|
||||
async def auth_more_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth more data` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Authentication method data (from a packet representing
|
||||
an `auth more data` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
response = self.auth_response(auth_data, **kwargs)
|
||||
if response:
|
||||
await sock.write(response)
|
||||
|
||||
return bytes(await sock.read())
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
response = self.auth_response(auth_data, **kwargs)
|
||||
if response is None:
|
||||
raise InterfaceError("Got a NULL auth response")
|
||||
|
||||
logger.debug("# request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
pkt = bytes(await sock.read())
|
||||
logger.debug("# server response packet: %s", pkt)
|
||||
|
||||
return pkt
|
||||
@@ -0,0 +1,105 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Clear Password Authentication Plugin."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from mysql.connector import errors
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLClearPasswordAuthPlugin"
|
||||
|
||||
|
||||
class MySQLClearPasswordAuthPlugin(MySQLAuthPlugin):
|
||||
"""Class implementing the MySQL Clear Password authentication plugin"""
|
||||
|
||||
def _prepare_password(self) -> bytes:
|
||||
"""Prepare and return password as as clear text.
|
||||
|
||||
Returns:
|
||||
bytes: Prepared password.
|
||||
"""
|
||||
return self._password.encode() + b"\x00"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "mysql_clear_password"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return False
|
||||
|
||||
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
|
||||
"""Return the prepared password to send to MySQL.
|
||||
|
||||
Raises:
|
||||
InterfaceError: When SSL is required by not enabled.
|
||||
|
||||
Returns:
|
||||
str: The prepared password.
|
||||
"""
|
||||
if self.requires_ssl and not self._ssl_enabled:
|
||||
raise errors.InterfaceError(f"{self.name} requires SSL")
|
||||
return self._prepare_password()
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
response = self.auth_response(auth_data, **kwargs)
|
||||
if response is None:
|
||||
raise errors.InterfaceError("Got a NULL auth response")
|
||||
|
||||
logger.debug("# request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
pkt = bytes(await sock.read())
|
||||
logger.debug("# server response packet: %s", pkt)
|
||||
|
||||
return pkt
|
||||
@@ -0,0 +1,121 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Native Password Authentication Plugin."""
|
||||
|
||||
import struct
|
||||
|
||||
from hashlib import sha1
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from mysql.connector.errors import InterfaceError
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLNativePasswordAuthPlugin"
|
||||
|
||||
|
||||
class MySQLNativePasswordAuthPlugin(MySQLAuthPlugin):
|
||||
"""Class implementing the MySQL Native Password authentication plugin"""
|
||||
|
||||
def _prepare_password(self, auth_data: bytes) -> bytes:
|
||||
"""Prepares and returns password as native MySQL 4.1+ password"""
|
||||
if not auth_data:
|
||||
raise InterfaceError("Missing authentication data (seed)")
|
||||
|
||||
if not self._password:
|
||||
return b""
|
||||
|
||||
hash4 = None
|
||||
try:
|
||||
hash1 = sha1(self._password.encode()).digest()
|
||||
hash2 = sha1(hash1).digest()
|
||||
hash3 = sha1(auth_data + hash2).digest()
|
||||
xored = [h1 ^ h3 for (h1, h3) in zip(hash1, hash3)]
|
||||
hash4 = struct.pack("20B", *xored)
|
||||
except (struct.error, TypeError) as err:
|
||||
raise InterfaceError(f"Failed scrambling password; {err}") from err
|
||||
|
||||
return hash4
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "mysql_native_password"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return False
|
||||
|
||||
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
|
||||
"""Make the client's authorization response.
|
||||
|
||||
Args:
|
||||
auth_data: Authorization data.
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Client's authorization response.
|
||||
"""
|
||||
return self._prepare_password(auth_data)
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
response = self.auth_response(auth_data, **kwargs)
|
||||
if response is None:
|
||||
raise InterfaceError("Got a NULL auth response")
|
||||
|
||||
logger.debug("# request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
pkt = bytes(await sock.read())
|
||||
logger.debug("# server response packet: %s", pkt)
|
||||
|
||||
return pkt
|
||||
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""SHA256 Password Authentication Plugin."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from mysql.connector import errors
|
||||
from mysql.connector.logger import logger
|
||||
|
||||
from . import MySQLAuthPlugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..network import MySQLSocket
|
||||
|
||||
AUTHENTICATION_PLUGIN_CLASS = "MySQLSHA256PasswordAuthPlugin"
|
||||
|
||||
|
||||
class MySQLSHA256PasswordAuthPlugin(MySQLAuthPlugin):
|
||||
"""Class implementing the MySQL SHA256 authentication plugin
|
||||
|
||||
Note that encrypting using RSA is not supported since the Python
|
||||
Standard Library does not provide this OpenSSL functionality.
|
||||
"""
|
||||
|
||||
def _prepare_password(self) -> bytes:
|
||||
"""Prepare and return password as as clear text.
|
||||
|
||||
Returns:
|
||||
password (bytes): Prepared password.
|
||||
"""
|
||||
return self._password.encode() + b"\x00"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Plugin official name."""
|
||||
return "sha256_password"
|
||||
|
||||
@property
|
||||
def requires_ssl(self) -> bool:
|
||||
"""Signals whether or not SSL is required."""
|
||||
return True
|
||||
|
||||
def auth_response(self, auth_data: bytes, **kwargs: Any) -> Optional[bytes]:
|
||||
"""Return the prepared password to send to MySQL.
|
||||
|
||||
Raises:
|
||||
InterfaceError: When SSL is required by not enabled.
|
||||
|
||||
Returns:
|
||||
str: The prepared password.
|
||||
"""
|
||||
if self.requires_ssl and not self.ssl_enabled:
|
||||
raise errors.InterfaceError(f"{self.name} requires SSL")
|
||||
return self._prepare_password()
|
||||
|
||||
async def auth_switch_response(
|
||||
self, sock: "MySQLSocket", auth_data: bytes, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handles server's `auth switch request` response.
|
||||
|
||||
Args:
|
||||
sock: Pointer to the socket connection.
|
||||
auth_data: Plugin provided data (extracted from a packet
|
||||
representing an `auth switch request` response).
|
||||
kwargs: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override the ones
|
||||
defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
packet: Last server's response after back-and-forth
|
||||
communication.
|
||||
"""
|
||||
response = self.auth_response(auth_data, **kwargs)
|
||||
if response is None:
|
||||
raise errors.InterfaceError("Got a NULL auth response")
|
||||
|
||||
logger.debug("# request: %s size: %s", response, len(response))
|
||||
await sock.write(response)
|
||||
|
||||
pkt = bytes(await sock.read())
|
||||
logger.debug("# server response packet: %s", pkt)
|
||||
|
||||
return pkt
|
||||
@@ -0,0 +1,325 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
"""Implements the MySQL Client/Server protocol."""
|
||||
|
||||
__all__ = ["MySQLProtocol"]
|
||||
|
||||
import struct
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ..constants import ClientFlag, ServerCmd
|
||||
from ..errors import InterfaceError, ProgrammingError, get_exception
|
||||
from ..logger import logger
|
||||
from ..protocol import (
|
||||
DEFAULT_CHARSET_ID,
|
||||
DEFAULT_MAX_ALLOWED_PACKET,
|
||||
MySQLProtocol as _MySQLProtocol,
|
||||
)
|
||||
from ..types import BinaryProtocolType, DescriptionType, EofPacketType, HandShakeType
|
||||
from ..utils import lc_int, read_lc_string_list
|
||||
from .network import MySQLSocket
|
||||
from .plugins import MySQLAuthPlugin, get_auth_plugin
|
||||
from .plugins.caching_sha2_password import MySQLCachingSHA2PasswordAuthPlugin
|
||||
|
||||
|
||||
class MySQLProtocol(_MySQLProtocol):
|
||||
"""Implements MySQL client/server protocol.
|
||||
|
||||
Create and parses MySQL packets.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def auth_plugin_first_response( # type: ignore[override]
|
||||
auth_data: bytes,
|
||||
username: str,
|
||||
password: str,
|
||||
auth_plugin: str,
|
||||
auth_plugin_class: Optional[str] = None,
|
||||
ssl_enabled: bool = False,
|
||||
plugin_config: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bytes, MySQLAuthPlugin]:
|
||||
"""Prepare the first authentication response.
|
||||
|
||||
Args:
|
||||
auth_data: Authorization data from initial handshake.
|
||||
username: Account's username.
|
||||
password: Account's password.
|
||||
client_flags: Integer representing client capabilities flags.
|
||||
auth_plugin: Authorization plugin name.
|
||||
auth_plugin_class: Authorization plugin class (has higher precedence
|
||||
than the authorization plugin name).
|
||||
ssl_enabled: Whether SSL is enabled or not.
|
||||
plugin_config: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override
|
||||
the ones defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
auth_response: Authorization plugin response.
|
||||
auth_strategy: Authorization plugin instance created based
|
||||
on the provided `auth_plugin` and `auth_plugin_class`
|
||||
parameters.
|
||||
|
||||
Raises:
|
||||
InterfaceError: If authentication fails or when got a NULL auth response.
|
||||
"""
|
||||
if not password and auth_plugin == "":
|
||||
# return auth response and an arbitrary auth strategy
|
||||
return b"\x00", MySQLCachingSHA2PasswordAuthPlugin(
|
||||
username, password, ssl_enabled=ssl_enabled
|
||||
)
|
||||
|
||||
if plugin_config is None:
|
||||
plugin_config = {}
|
||||
|
||||
try:
|
||||
auth_strategy = get_auth_plugin(auth_plugin, auth_plugin_class)(
|
||||
username, password, ssl_enabled=ssl_enabled
|
||||
)
|
||||
auth_response = auth_strategy.auth_response(auth_data, **plugin_config)
|
||||
except (TypeError, InterfaceError) as err:
|
||||
raise InterfaceError(f"Failed authentication: {err}") from err
|
||||
|
||||
if auth_response is None:
|
||||
raise InterfaceError(
|
||||
"Got NULL auth response while authenticating with "
|
||||
f"plugin {auth_strategy.name}"
|
||||
)
|
||||
|
||||
auth_response = lc_int(len(auth_response)) + auth_response
|
||||
|
||||
return auth_response, auth_strategy
|
||||
|
||||
@staticmethod
|
||||
def make_auth( # type: ignore[override]
|
||||
handshake: HandShakeType,
|
||||
username: str,
|
||||
password: str,
|
||||
database: Optional[str] = None,
|
||||
charset: int = DEFAULT_CHARSET_ID,
|
||||
client_flags: int = 0,
|
||||
max_allowed_packet: int = DEFAULT_MAX_ALLOWED_PACKET,
|
||||
auth_plugin: Optional[str] = None,
|
||||
auth_plugin_class: Optional[str] = None,
|
||||
conn_attrs: Optional[Dict[str, str]] = None,
|
||||
is_change_user_request: bool = False,
|
||||
ssl_enabled: bool = False,
|
||||
plugin_config: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[bytes, MySQLAuthPlugin]:
|
||||
"""Make a MySQL Authentication packet.
|
||||
|
||||
Args:
|
||||
handshake: Initial handshake.
|
||||
username: Account's username.
|
||||
password: Account's password.
|
||||
database: Initial database name for the connection
|
||||
charset: Client charset (see [2]), only the lower 8-bits.
|
||||
client_flags: Integer representing client capabilities flags.
|
||||
max_allowed_packet: Maximum packet size.
|
||||
auth_plugin: Authorization plugin name.
|
||||
auth_plugin_class: Authorization plugin class (has higher precedence
|
||||
than the authorization plugin name).
|
||||
conn_attrs: Connection attributes.
|
||||
is_change_user_request: Whether is a `change user request` operation or not.
|
||||
ssl_enabled: Whether SSL is enabled or not.
|
||||
plugin_config: Custom configuration to be passed to the auth plugin
|
||||
when invoked. The parameters defined here will override
|
||||
the one defined in the auth plugin itself.
|
||||
|
||||
Returns:
|
||||
handshake_response: Handshake response as per [1].
|
||||
auth_strategy: Authorization plugin instance created based
|
||||
on the provided `auth_plugin` and `auth_plugin_class`.
|
||||
|
||||
Raises:
|
||||
ProgrammingError: Handshake misses authentication info.
|
||||
|
||||
References:
|
||||
[1]: https://dev.mysql.com/doc/dev/mysql-server/latest/\
|
||||
page_protocol_connection_phase_packets_protocol_handshake_response.html
|
||||
|
||||
[2]: https://dev.mysql.com/doc/dev/mysql-server/latest/\
|
||||
page_protocol_basic_character_set.html#a_protocol_character_set
|
||||
"""
|
||||
b_username = username.encode()
|
||||
response_payload = []
|
||||
|
||||
if is_change_user_request:
|
||||
logger.debug("Got a `change user` request")
|
||||
|
||||
logger.debug("Starting authorization phase")
|
||||
if handshake is None:
|
||||
raise ProgrammingError("Got a NULL handshake") from None
|
||||
|
||||
if handshake.get("auth_data") is None:
|
||||
raise ProgrammingError("Handshake misses authentication info") from None
|
||||
|
||||
try:
|
||||
auth_plugin = auth_plugin or handshake["auth_plugin"] # type: ignore[assignment]
|
||||
except (TypeError, KeyError) as err:
|
||||
raise ProgrammingError(
|
||||
f"Handshake misses authentication plugin info ({err})"
|
||||
) from None
|
||||
|
||||
logger.debug("The provided initial strategy is %s", auth_plugin)
|
||||
|
||||
if is_change_user_request:
|
||||
response_payload.append(
|
||||
struct.pack(
|
||||
f"<B{len(b_username)}sx",
|
||||
ServerCmd.CHANGE_USER,
|
||||
b_username,
|
||||
)
|
||||
)
|
||||
else:
|
||||
filler = "x" * 23
|
||||
response_payload.append(
|
||||
struct.pack(
|
||||
f"<IIB{filler}{len(b_username)}sx",
|
||||
client_flags,
|
||||
max_allowed_packet,
|
||||
charset,
|
||||
b_username,
|
||||
)
|
||||
)
|
||||
|
||||
# auth plugin response
|
||||
auth_response, auth_strategy = MySQLProtocol.auth_plugin_first_response(
|
||||
auth_data=handshake["auth_data"], # type: ignore[arg-type]
|
||||
username=username,
|
||||
password=password,
|
||||
auth_plugin=auth_plugin,
|
||||
auth_plugin_class=auth_plugin_class,
|
||||
ssl_enabled=ssl_enabled,
|
||||
plugin_config=plugin_config,
|
||||
)
|
||||
response_payload.append(auth_response)
|
||||
|
||||
# database name
|
||||
response_payload.append(MySQLProtocol.connect_with_db(client_flags, database))
|
||||
|
||||
# charset
|
||||
if is_change_user_request:
|
||||
response_payload.append(struct.pack("<H", charset))
|
||||
|
||||
# plugin name
|
||||
if client_flags & ClientFlag.PLUGIN_AUTH:
|
||||
response_payload.append(auth_plugin.encode() + b"\x00")
|
||||
|
||||
# connection attributes
|
||||
if (client_flags & ClientFlag.CONNECT_ARGS) and conn_attrs is not None:
|
||||
response_payload.append(MySQLProtocol.make_conn_attrs(conn_attrs))
|
||||
|
||||
return b"".join(response_payload), auth_strategy
|
||||
|
||||
# pylint: disable=invalid-overridden-method
|
||||
async def read_binary_result( # type: ignore[override]
|
||||
self,
|
||||
sock: MySQLSocket,
|
||||
columns: List[DescriptionType],
|
||||
count: int = 1,
|
||||
charset: str = "utf-8",
|
||||
read_timeout: Optional[int] = None,
|
||||
) -> Tuple[
|
||||
List[Tuple[BinaryProtocolType, ...]],
|
||||
Optional[EofPacketType],
|
||||
]:
|
||||
"""Read MySQL binary protocol result.
|
||||
|
||||
Reads all or given number of binary resultset rows from the socket.
|
||||
"""
|
||||
rows = []
|
||||
eof = None
|
||||
values = None
|
||||
i = 0
|
||||
while True:
|
||||
if eof or i == count:
|
||||
break
|
||||
packet = await sock.read(read_timeout)
|
||||
if packet[4] == 254:
|
||||
eof = self.parse_eof(packet)
|
||||
values = None
|
||||
elif packet[4] == 0:
|
||||
eof = None
|
||||
values = self._parse_binary_values(columns, packet[5:], charset)
|
||||
if eof is None and values is not None:
|
||||
rows.append(values)
|
||||
elif eof is None and values is None:
|
||||
raise get_exception(packet)
|
||||
i += 1
|
||||
return (rows, eof)
|
||||
|
||||
# pylint: disable=invalid-overridden-method
|
||||
async def read_text_result( # type: ignore[override]
|
||||
self,
|
||||
sock: MySQLSocket,
|
||||
version: Tuple[int, ...],
|
||||
count: int = 1,
|
||||
read_timeout: Optional[int] = None,
|
||||
) -> Tuple[
|
||||
List[Tuple[Optional[bytes], ...]],
|
||||
Optional[EofPacketType],
|
||||
]:
|
||||
"""Read MySQL text result.
|
||||
|
||||
Reads all or given number of rows from the socket.
|
||||
|
||||
Returns a tuple with 2 elements: a list with all rows and
|
||||
the EOF packet.
|
||||
"""
|
||||
# Keep unused 'version' for API backward compatibility
|
||||
_ = version
|
||||
rows = []
|
||||
eof = None
|
||||
rowdata = None
|
||||
i = 0
|
||||
while True:
|
||||
if eof or i == count:
|
||||
break
|
||||
packet = await sock.read(read_timeout)
|
||||
if packet.startswith(b"\xff\xff\xff"):
|
||||
datas = [packet[4:]]
|
||||
packet = await sock.read(read_timeout)
|
||||
while packet.startswith(b"\xff\xff\xff"):
|
||||
datas.append(packet[4:])
|
||||
packet = await sock.read(read_timeout)
|
||||
datas.append(packet[4:])
|
||||
rowdata = read_lc_string_list(b"".join(datas))
|
||||
elif packet[4] == 254 and packet[0] < 7:
|
||||
eof = self.parse_eof(packet)
|
||||
rowdata = None
|
||||
else:
|
||||
eof = None
|
||||
rowdata = read_lc_string_list(bytes(packet[4:]))
|
||||
if eof is None and rowdata is not None:
|
||||
rows.append(rowdata)
|
||||
elif eof is None and rowdata is None:
|
||||
raise get_exception(packet)
|
||||
i += 1
|
||||
return rows, eof
|
||||
199
myenv/lib/python3.11/site-packages/mysql/connector/aio/utils.py
Normal file
199
myenv/lib/python3.11/site-packages/mysql/connector/aio/utils.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# Copyright (c) 2023, 2024, Oracle and/or its affiliates.
|
||||
#
|
||||
# This program is free software; you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License, version 2.0, as
|
||||
# published by the Free Software Foundation.
|
||||
#
|
||||
# This program is designed to work with certain software (including
|
||||
# but not limited to OpenSSL) that is licensed under separate terms,
|
||||
# as designated in a particular file or component or in included license
|
||||
# documentation. The authors of MySQL hereby grant you an
|
||||
# additional permission to link the program and your derivative works
|
||||
# with the separately licensed software that they have either included with
|
||||
# the program or referenced in the documentation.
|
||||
#
|
||||
# Without limiting anything contained in the foregoing, this file,
|
||||
# which is part of MySQL Connector/Python, is also subject to the
|
||||
# Universal FOSS Exception, version 1.0, a copy of which can be found at
|
||||
# http://oss.oracle.com/licenses/universal-foss-exception.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful, but
|
||||
# WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
|
||||
# See the GNU General Public License, version 2.0, for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program; if not, write to the Free Software Foundation, Inc.,
|
||||
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
||||
|
||||
# mypy: disable-error-code="attr-defined"
|
||||
# pylint: disable=protected-access
|
||||
|
||||
"""Utilities."""
|
||||
|
||||
__all__ = ["to_thread", "open_connection"]
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import functools
|
||||
import warnings
|
||||
|
||||
from mysql.connector.errors import ReadTimeoutError, WriteTimeoutError
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
ssl = None
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mysql.connector.aio.abstracts import MySQLConnectionAbstract
|
||||
|
||||
__all__.append("StreamWriter")
|
||||
|
||||
|
||||
class StreamReaderProtocol(asyncio.StreamReaderProtocol):
|
||||
"""Extends asyncio.streams.StreamReaderProtocol for adding start_tls().
|
||||
|
||||
The ``start_tls()`` is based on ``asyncio.streams.StreamWriter`` introduced
|
||||
in Python 3.11. It provides the same functionality for older Python versions.
|
||||
"""
|
||||
|
||||
def _replace_writer(self, writer: asyncio.StreamWriter) -> None:
|
||||
"""Replace stream writer.
|
||||
|
||||
Args:
|
||||
writer: Stream Writer.
|
||||
"""
|
||||
transport = writer.transport
|
||||
self._stream_writer = writer
|
||||
self._transport = transport
|
||||
self._over_ssl = transport.get_extra_info("sslcontext") is not None
|
||||
|
||||
|
||||
class StreamWriter(asyncio.streams.StreamWriter):
|
||||
"""Extends asyncio.streams.StreamWriter for adding start_tls().
|
||||
|
||||
The ``start_tls()`` is based on ``asyncio.streams.StreamWriter`` introduced
|
||||
in Python 3.11. It provides the same functionality for older Python versions.
|
||||
"""
|
||||
|
||||
async def start_tls(
|
||||
self,
|
||||
ssl_context: ssl.SSLContext,
|
||||
*,
|
||||
server_hostname: str = None,
|
||||
ssl_handshake_timeout: int = None,
|
||||
) -> None:
|
||||
"""Upgrade an existing stream-based connection to TLS.
|
||||
|
||||
Args:
|
||||
ssl_context: Configured SSL context.
|
||||
server_hostname: Server host name.
|
||||
ssl_handshake_timeout: SSL handshake timeout.
|
||||
"""
|
||||
server_side = self._protocol._client_connected_cb is not None
|
||||
protocol = self._protocol
|
||||
await self.drain()
|
||||
new_transport = await self._loop.start_tls(
|
||||
# pylint: disable=access-member-before-definition
|
||||
self._transport, # type: ignore[has-type]
|
||||
protocol,
|
||||
ssl_context,
|
||||
server_side=server_side,
|
||||
server_hostname=server_hostname,
|
||||
ssl_handshake_timeout=ssl_handshake_timeout,
|
||||
)
|
||||
self._transport = ( # pylint: disable=attribute-defined-outside-init
|
||||
new_transport
|
||||
)
|
||||
protocol._replace_writer(self)
|
||||
|
||||
|
||||
async def open_connection(
|
||||
host: str = None, port: int = None, *, limit: int = 2**16, **kwds: Any
|
||||
) -> Tuple[asyncio.StreamReader, StreamWriter]:
|
||||
"""A wrapper for create_connection() returning a (reader, writer) pair.
|
||||
|
||||
This function is based on ``asyncio.streams.open_connection`` and adds a custom
|
||||
stream reader.
|
||||
|
||||
MySQL expects TLS negotiation to happen in the middle of a TCP connection, not at
|
||||
the start.
|
||||
This function in conjunction with ``_StreamReaderProtocol`` and ``_StreamWriter``
|
||||
allows the TLS negotiation on an existing connection.
|
||||
|
||||
Args:
|
||||
host: Server host name.
|
||||
port: Server port.
|
||||
limit: The buffer size limit used by the returned ``StreamReader`` instance.
|
||||
By default the limit is set to 64 KiB.
|
||||
|
||||
Returns:
|
||||
tuple: Returns a pair of reader and writer objects that are instances of
|
||||
``StreamReader`` and ``StreamWriter`` classes.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
reader = asyncio.streams.StreamReader(limit=limit, loop=loop)
|
||||
protocol = StreamReaderProtocol(reader, loop=loop)
|
||||
transport, _ = await loop.create_connection(lambda: protocol, host, port, **kwds)
|
||||
writer = StreamWriter(transport, protocol, reader, loop)
|
||||
return reader, writer
|
||||
|
||||
|
||||
async def to_thread(func: Callable, *args: Any, **kwargs: Any) -> asyncio.Future:
|
||||
"""Asynchronously run function ``func`` in a separate thread.
|
||||
|
||||
This function is based on ``asyncio.to_thread()`` introduced in Python 3.9, which
|
||||
provides the same functionality for older Python versions.
|
||||
|
||||
Returns:
|
||||
coroutine: A coroutine that can be awaited to get the eventual result of
|
||||
``func``.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
ctx = contextvars.copy_context()
|
||||
func_call = functools.partial(ctx.run, func, *args, **kwargs)
|
||||
return await loop.run_in_executor(None, func_call)
|
||||
|
||||
|
||||
def deprecated(reason: str) -> Callable:
|
||||
"""Use it to decorate deprecated methods."""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> Callable:
|
||||
warnings.warn(
|
||||
f"Call to deprecated function {func.__name__}. Reason: {reason}",
|
||||
category=DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def handle_read_write_timeout() -> Callable:
|
||||
"""
|
||||
Decorator to close the current connection if a read or a write timeout
|
||||
is raised by the method passed via the func parameter.
|
||||
"""
|
||||
|
||||
def decorator(cnx_method: Callable) -> Callable:
|
||||
@functools.wraps(cnx_method)
|
||||
async def handle_cnx_method(
|
||||
cnx: "MySQLConnectionAbstract", *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
try:
|
||||
return await cnx_method(cnx, *args, **kwargs)
|
||||
except Exception as err:
|
||||
if isinstance(err, (ReadTimeoutError, WriteTimeoutError)):
|
||||
await cnx.close()
|
||||
raise err
|
||||
|
||||
return handle_cnx_method
|
||||
|
||||
return decorator
|
||||
Reference in New Issue
Block a user