Module dlr.fl.client.client_server
View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0
from http.server import BaseHTTPRequestHandler
import json
import logging
from typing import Any
from uuid import UUID
from . import Communication
class ClientServerHandler(BaseHTTPRequestHandler):
"""
Default client server handler.
The handler will receive notifications from the server and forward them to the client communication module.
"""
_logger = logging.getLogger("fl.client")
"""Logger instance for the client server handler."""
def do_POST(self):
"""
Receive a notification from the server.
All notifications are expected to be in JSON format and contain the following fields:
- `notification_type`: The type of the notification.
- `training_uuid`: The UUID of the corresponding training.
- `body`: The notification body.
If an error occurs during the handling of the notification, the error will be logged
and a 500 response is sent.
"""
try:
content_len = int(self.headers.get("Content-Length"))
request_content = self.rfile.read(content_len)
content = json.loads(request_content)
self._logger.debug("notification received: " + json.dumps(content))
self.handle_message(
str(content["notification_type"]),
UUID(content["training_uuid"]),
content["body"],
)
except Exception as e:
self._logger.fatal(e)
self.send_response(500)
self.end_headers()
def handle_message(self, notification_type: str, training_id: UUID, data: Any):
"""
Forward the notification to the corresponding client communication module function.
Args:
notification_type (str): type of the notification
training_id (UUID): UUID of the corresponding training
data (Any): notification body
"""
self._logger.info(f"receive notification '{notification_type}' for training '{training_id}'")
match notification_type:
case "TRAINING_START":
Communication.init_training(
training_id=training_id,
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(200)
case "UPDATE_ROUND_START":
Communication.start_training(
training_id=training_id,
round=int(data["round"]),
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(202)
case "MODEL_TEST_ROUND":
Communication.start_testing(
training_id=training_id,
round=int(data["round"]),
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(202)
case "TRAINING_FINISHED":
Communication.end_training(
training_id=training_id,
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(200)
case _:
status_code = Communication.unknown_message(
mesage_type=data["notification_type"],
training_id=training_id,
data=data["data"],
)
self.send_response(status_code)
Classes
ClientServerHandler
Default client server handler.
The handler will receive notifications from the server and forward them to the client communication module.
View Source
class ClientServerHandler(BaseHTTPRequestHandler):
"""
Default client server handler.
The handler will receive notifications from the server and forward them to the client communication module.
"""
_logger = logging.getLogger("fl.client")
"""Logger instance for the client server handler."""
def do_POST(self):
"""
Receive a notification from the server.
All notifications are expected to be in JSON format and contain the following fields:
- `notification_type`: The type of the notification.
- `training_uuid`: The UUID of the corresponding training.
- `body`: The notification body.
If an error occurs during the handling of the notification, the error will be logged
and a 500 response is sent.
"""
try:
content_len = int(self.headers.get("Content-Length"))
request_content = self.rfile.read(content_len)
content = json.loads(request_content)
self._logger.debug("notification received: " + json.dumps(content))
self.handle_message(
str(content["notification_type"]),
UUID(content["training_uuid"]),
content["body"],
)
except Exception as e:
self._logger.fatal(e)
self.send_response(500)
self.end_headers()
def handle_message(self, notification_type: str, training_id: UUID, data: Any):
"""
Forward the notification to the corresponding client communication module function.
Args:
notification_type (str): type of the notification
training_id (UUID): UUID of the corresponding training
data (Any): notification body
"""
self._logger.info(f"receive notification '{notification_type}' for training '{training_id}'")
match notification_type:
case "TRAINING_START":
Communication.init_training(
training_id=training_id,
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(200)
case "UPDATE_ROUND_START":
Communication.start_training(
training_id=training_id,
round=int(data["round"]),
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(202)
case "MODEL_TEST_ROUND":
Communication.start_testing(
training_id=training_id,
round=int(data["round"]),
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(202)
case "TRAINING_FINISHED":
Communication.end_training(
training_id=training_id,
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(200)
case _:
status_code = Communication.unknown_message(
mesage_type=data["notification_type"],
training_id=training_id,
data=data["data"],
)
self.send_response(status_code)
Ancestors (in MRO)
- http.server.BaseHTTPRequestHandler
- socketserver.StreamRequestHandler
- socketserver.BaseRequestHandler
Class variables
Methods
address_string¶
Return the client address.
View Source
date_time_string¶
Return the current date and time formatted for a message header.
View Source
do_POST
Receive a notification from the server.
All notifications are expected to be in JSON format and contain the following fields:
- notification_type: The type of the notification.
- training_uuid: The UUID of the corresponding training.
- body: The notification body.
If an error occurs during the handling of the notification, the error will be logged and a 500 response is sent.
View Source
def do_POST(self):
    """
    Receive a notification from the server.

    All notifications are expected to be in JSON format and contain the following fields:

    - `notification_type`: The type of the notification.
    - `training_uuid`: The UUID of the corresponding training.
    - `body`: The notification body.

    On success the status code selected by `handle_message` is sent. If an error occurs
    during the handling of the notification (missing or negative `Content-Length`,
    malformed JSON, missing fields, or a failure in the communication module), the error
    is logged together with its traceback and a 500 response is sent.
    """
    try:
        # A missing header yields int(None) -> TypeError, which is reported as a 500 below.
        content_len = int(self.headers.get("Content-Length"))
        if content_len < 0:
            # A negative size would make rfile.read() read until EOF instead of one body.
            raise ValueError(f"invalid Content-Length: {content_len}")
        request_content = self.rfile.read(content_len)
        content = json.loads(request_content)
        self._logger.debug("notification received: %s", json.dumps(content))
        self.handle_message(
            str(content["notification_type"]),
            UUID(content["training_uuid"]),
            content["body"],
        )
    except Exception as e:
        # Top-level request boundary: every failure is reported to the server as a 500.
        self._logger.critical("failed to handle notification: %s", e, exc_info=True)
        self.send_response(500)
    # send_response() only buffers the status line and headers; end_headers() writes them
    # out. It must also run on the success path, otherwise no response is ever sent.
    self.end_headers()
end_headers¶
Send the blank line ending the MIME headers.
View Source
finish¶
View Source
flush_headers¶
View Source
handle¶
Handle multiple requests if necessary.
View Source
handle_expect_100¶
Decide what to do with an "Expect: 100-continue" header.
If the client is expecting a 100 Continue response, we must respond with either a 100 Continue or a final response before waiting for the request body. The default is to always respond with a 100 Continue. You can behave differently (for example, reject unauthorized requests) by overriding this method.
This method should either return True (possibly after sending a 100 Continue response) or send an error response and return False.
View Source
def handle_expect_100(self):
    """React to an ``Expect: 100-continue`` request header.

    A client sending this header waits for either ``100 Continue`` or a
    final response before it transmits the request body. This default
    implementation always accepts: it writes an interim ``100 Continue``
    status line and returns True. Subclasses may override it to refuse
    (e.g. for unauthorized requests) by sending an error response and
    returning False.
    """
    continue_status = HTTPStatus.CONTINUE
    # Status line only: no Server/Date headers for an interim response.
    self.send_response_only(continue_status)
    self.end_headers()
    accepted = True
    return accepted
handle_message
Forward the notification to the corresponding client communication module function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
notification_type | str | type of the notification | required |
training_id | UUID | UUID of the corresponding training | required |
data | Any | notification body | required |
View Source
def handle_message(self, notification_type: str, training_id: UUID, data: Any):
"""
Forward the notification to the corresponding client communication module function.
Args:
notification_type (str): type of the notification
training_id (UUID): UUID of the corresponding training
data (Any): notification body
"""
self._logger.info(f"receive notification '{notification_type}' for training '{training_id}'")
match notification_type:
case "TRAINING_START":
Communication.init_training(
training_id=training_id,
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(200)
case "UPDATE_ROUND_START":
Communication.start_training(
training_id=training_id,
round=int(data["round"]),
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(202)
case "MODEL_TEST_ROUND":
Communication.start_testing(
training_id=training_id,
round=int(data["round"]),
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(202)
case "TRAINING_FINISHED":
Communication.end_training(
training_id=training_id,
model_id=UUID(data["global_model_uuid"]),
)
self.send_response(200)
case _:
status_code = Communication.unknown_message(
mesage_type=data["notification_type"],
training_id=training_id,
data=data["data"],
)
self.send_response(status_code)
handle_one_request¶
Handle a single HTTP request.
You normally don't need to override this method; see the class doc string for information on how to handle specific HTTP commands such as GET and POST.
View Source
def handle_one_request(self):
    """Handle a single HTTP request.

    You normally don't need to override this method; see the class
    __doc__ string for information on how to handle specific HTTP
    commands such as GET and POST.

    Reads one request line, lets parse_request() process it and the
    headers, then dispatches to the ``do_<COMMAND>`` method named after
    the request command (e.g. ``do_POST``). Request lines longer than
    65536 bytes are rejected with 414, and unknown commands with 501.
    """
    try:
        # Read one byte past the limit so an over-long line can be detected.
        self.raw_requestline = self.rfile.readline(65537)
        if len(self.raw_requestline) > 65536:
            # The line was never parsed: clear the attributes that
            # send_error() and the logging helpers read.
            self.requestline = ''
            self.request_version = ''
            self.command = ''
            self.send_error(HTTPStatus.REQUEST_URI_TOO_LONG)
            return
        if not self.raw_requestline:
            # Empty read: nothing more on this connection.
            self.close_connection = True
            return
        if not self.parse_request():
            # An error code has been sent, just exit
            return
        # Dispatch by naming convention: "POST" -> self.do_POST, etc.
        mname = 'do_' + self.command
        if not hasattr(self, mname):
            self.send_error(
                HTTPStatus.NOT_IMPLEMENTED,
                "Unsupported method (%r)" % self.command)
            return
        method = getattr(self, mname)
        method()
        self.wfile.flush() #actually send the response if not already done.
    except TimeoutError as e:
        #a read or a write timed out.  Discard this connection
        self.log_error("Request timed out: %r", e)
        self.close_connection = True
        return
log_date_time_string¶
Return the current time formatted for logging.
View Source
log_error¶
Log an error.
This is called when a request cannot be fulfilled. By default it passes the message on to log_message().
Arguments are the same as for log_message().
XXX This should go to the separate error log.
View Source
log_message¶
Log an arbitrary message.
This is used by all other logging functions. Override it if you have specific logging wishes.
The first argument, FORMAT, is a format string for the message to be logged. If the format string contains any % escapes requiring parameters, they should be specified as subsequent arguments (it's just like printf!).
The client ip and current date/time are prefixed to every message.
Unicode control characters are replaced with escaped hex before writing the output to stderr.
View Source
def log_message(self, format, *args):
    """Write one log line for this handler to standard error.

    ``format`` is a %-style format string and ``args`` are its
    substitution values, exactly as with printf. Every line is prefixed
    with the client address and the current date/time. Unicode control
    characters in the rendered message are replaced with escaped hex
    (via ``self._control_char_table``) before writing. Override this
    method to route logging elsewhere.
    """
    rendered = format % args
    client = self.address_string()
    stamp = self.log_date_time_string()
    sanitized = rendered.translate(self._control_char_table)
    line = "%s - - [%s] %s\n" % (client, stamp, sanitized)
    sys.stderr.write(line)
log_request¶
Log an accepted request.
This is called by send_response().
View Source
parse_request¶
Parse a request (internal).
The request should be stored in self.raw_requestline; the results are in self.command, self.path, self.request_version and self.headers.
Return True for success, False for failure; on failure, any relevant error response has already been sent back.
View Source
def parse_request(self):
    """Parse a request (internal).

    The request should be stored in self.raw_requestline; the results
    are in self.command, self.path, self.request_version and
    self.headers.

    Return True for success, False for failure; on failure, any relevant
    error response has already been sent back.

    Side effects beyond the results above: sets self.requestline and
    self.close_connection, and may call handle_expect_100() when the
    client sent ``Expect: 100-continue``.
    """
    self.command = None  # set in case of error on the first line
    self.request_version = version = self.default_request_version
    # Close after this request unless a later check enables keep-alive.
    self.close_connection = True
    requestline = str(self.raw_requestline, 'iso-8859-1')
    requestline = requestline.rstrip('\r\n')
    self.requestline = requestline
    words = requestline.split()
    if len(words) == 0:
        return False

    if len(words) >= 3:  # Enough to determine protocol version
        version = words[-1]
        try:
            if not version.startswith('HTTP/'):
                raise ValueError
            base_version_number = version.split('/', 1)[1]
            version_number = base_version_number.split(".")
            # RFC 2145 section 3.1 says there can be only one "." and
            #   - major and minor numbers MUST be treated as
            #      separate integers;
            #   - HTTP/2.4 is a lower version than HTTP/2.13, which in
            #      turn is lower than HTTP/12.3;
            #   - Leading zeros MUST be ignored by recipients.
            if len(version_number) != 2:
                raise ValueError
            # Tuple of ints so versions compare numerically, not lexically.
            version_number = int(version_number[0]), int(version_number[1])
        except (ValueError, IndexError):
            self.send_error(
                HTTPStatus.BAD_REQUEST,
                "Bad request version (%r)" % version)
            return False
        # Keep-alive by default only if both the client and this handler speak HTTP/1.1+.
        if version_number >= (1, 1) and self.protocol_version >= "HTTP/1.1":
            self.close_connection = False
        if version_number >= (2, 0):
            self.send_error(
                HTTPStatus.HTTP_VERSION_NOT_SUPPORTED,
                "Invalid HTTP version (%s)" % base_version_number)
            return False
        self.request_version = version

    if not 2 <= len(words) <= 3:
        self.send_error(
            HTTPStatus.BAD_REQUEST,
            "Bad request syntax (%r)" % requestline)
        return False
    command, path = words[:2]
    if len(words) == 2:
        # Two words means an HTTP/0.9 request, which only allows GET and no keep-alive.
        self.close_connection = True
        if command != 'GET':
            self.send_error(
                HTTPStatus.BAD_REQUEST,
                "Bad HTTP/0.9 request type (%r)" % command)
            return False
    self.command, self.path = command, path

    # gh-87389: The purpose of replacing '//' with '/' is to protect
    # against open redirect attacks possibly triggered if the path starts
    # with '//' because http clients treat //path as an absolute URI
    # without scheme (similar to http://path) rather than a path.
    if self.path.startswith('//'):
        self.path = '/' + self.path.lstrip('/')  # Reduce to a single /

    # Examine the headers and look for a Connection directive.
    try:
        self.headers = http.client.parse_headers(self.rfile,
                                                 _class=self.MessageClass)
    except http.client.LineTooLong as err:
        self.send_error(
            HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
            "Line too long",
            str(err))
        return False
    except http.client.HTTPException as err:
        self.send_error(
            HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
            "Too many headers",
            str(err)
        )
        return False

    # An explicit Connection header overrides the version-based default above.
    conntype = self.headers.get('Connection', "")
    if conntype.lower() == 'close':
        self.close_connection = True
    elif (conntype.lower() == 'keep-alive' and
          self.protocol_version >= "HTTP/1.1"):
        self.close_connection = False
    # Examine the headers and look for an Expect directive
    expect = self.headers.get('Expect', "")
    if (expect.lower() == "100-continue" and
            self.protocol_version >= "HTTP/1.1" and
            self.request_version >= "HTTP/1.1"):
        if not self.handle_expect_100():
            return False
    return True
send_error
Send and log an error reply.
Arguments are:
- code: an HTTP error code (3 digits).
- message: a simple optional one-line reason phrase (*( HTAB / SP / VCHAR / %x80-FF )); defaults to the short entry matching the response code.
- explain: a detailed message; defaults to the long entry matching the response code.
This sends an error response (so it must be called before any output has been generated), logs the error, and finally sends a piece of HTML explaining the error to the user.
View Source
def send_error(self, code, message=None, explain=None):
    """Send and log an error reply.

    Arguments are
    * code: an HTTP error code
            3 digits
    * message: a simple optional 1 line reason phrase.
               *( HTAB / SP / VCHAR / %x80-FF )
               defaults to short entry matching the response code
    * explain: a detailed message defaults to the long entry
               matching the response code.

    This sends an error response (so it must be called before any
    output has been generated), logs the error, and finally sends
    a piece of HTML explaining the error to the user.

    The response always carries ``Connection: close``. No body is sent
    for HEAD requests or for codes that must not have one.
    """
    try:
        shortmsg, longmsg = self.responses[code]
    except KeyError:
        # Codes missing from the responses table get placeholder texts.
        shortmsg, longmsg = '???', '???'
    if message is None:
        message = shortmsg
    if explain is None:
        explain = longmsg
    self.log_error("code %d, message %s", code, message)
    self.send_response(code, message)
    self.send_header('Connection', 'close')

    # Message body is omitted for cases described in:
    #  - RFC7230: 3.3. 1xx, 204(No Content), 304(Not Modified)
    #  - RFC7231: 6.3.6. 205(Reset Content)
    body = None
    if (code >= 200 and
        code not in (HTTPStatus.NO_CONTENT,
                     HTTPStatus.RESET_CONTENT,
                     HTTPStatus.NOT_MODIFIED)):
        # HTML encode to prevent Cross Site Scripting attacks
        # (see bug #1100201)
        content = (self.error_message_format % {
            'code': code,
            'message': html.escape(message, quote=False),
            'explain': html.escape(explain, quote=False)
        })
        body = content.encode('UTF-8', 'replace')
        # Body-describing headers must be buffered before end_headers() below.
        self.send_header("Content-Type", self.error_content_type)
        self.send_header('Content-Length', str(len(body)))
    self.end_headers()

    # HEAD responses advertise the length but never carry the body itself.
    if self.command != 'HEAD' and body:
        self.wfile.write(body)
send_header¶
Send a MIME header to the headers buffer.
View Source
def send_header(self, keyword, value):
    """Queue one MIME header line and track the Connection directive.

    The header is appended to the pending headers buffer, except for
    HTTP/0.9 requests, which have no headers. A ``Connection: close`` or
    ``Connection: keep-alive`` header additionally updates
    ``close_connection``, even for HTTP/0.9.
    """
    if self.request_version != 'HTTP/0.9':
        if not hasattr(self, '_headers_buffer'):
            self._headers_buffer = []
        header_line = "%s: %s\r\n" % (keyword, value)
        self._headers_buffer.append(header_line.encode('latin-1', 'strict'))

    if keyword.lower() != 'connection':
        return
    directive = value.lower()
    if directive == 'close':
        self.close_connection = True
    elif directive == 'keep-alive':
        self.close_connection = False
send_response
Add the response header to the headers buffer and log the response code.
Also send two standard headers with the server software version and the current date.
View Source
def send_response(self, code, message=None):
    """Log the response code and buffer the status line plus default headers.

    After logging the request via log_request(), the status line is queued
    with send_response_only(), followed by a ``Server`` header (the server
    software version) and a ``Date`` header (the current date).
    """
    self.log_request(code)
    self.send_response_only(code, message)
    server_banner = self.version_string()
    self.send_header('Server', server_banner)
    timestamp = self.date_time_string()
    self.send_header('Date', timestamp)
send_response_only¶
Send the response header only.
View Source
def send_response_only(self, code, message=None):
    """Buffer just the status line, without any default headers.

    When ``message`` is None the short reason phrase for ``code`` is looked up
    in ``self.responses`` (an empty string if the code is unknown). Nothing is
    buffered for HTTP/0.9 requests.
    """
    if self.request_version == 'HTTP/0.9':
        return
    if message is None:
        message = self.responses[code][0] if code in self.responses else ''
    if not hasattr(self, '_headers_buffer'):
        self._headers_buffer = []
    status_line = "%s %d %s\r\n" % (self.protocol_version, code, message)
    self._headers_buffer.append(status_line.encode('latin-1', 'strict'))
setup¶
View Source
def setup(self):
self.connection = self.request
if self.timeout is not None:
self.connection.settimeout(self.timeout)
if self.disable_nagle_algorithm:
self.connection.setsockopt(socket.IPPROTO_TCP,
socket.TCP_NODELAY, True)
self.rfile = self.connection.makefile('rb', self.rbufsize)
if self.wbufsize == 0:
self.wfile = _SocketWriter(self.connection)
else:
self.wfile = self.connection.makefile('wb', self.wbufsize)
version_string¶
Return the server software version string.