Skip to content

Module dlr.fl.client.client_server

View Source
# SPDX-FileCopyrightText: 2024 Benedikt Franke <benedikt.franke@dlr.de>
# SPDX-FileCopyrightText: 2024 Florian Heinrich <florian.heinrich@dlr.de>
#
# SPDX-License-Identifier: Apache-2.0

from http.server import BaseHTTPRequestHandler
import json
import logging
from typing import Any
from uuid import UUID

from . import Communication


class ClientServerHandler(BaseHTTPRequestHandler):
    """
    Default client server handler.

    The handler will receive notifications from the server and forward them to the client communication module.
    """

    _logger = logging.getLogger("fl.client")
    """Logger instance for the client server handler."""

    def do_POST(self):
        """
        Receive a notification from the server.

        All notifications are expected to be in JSON format and contain the following fields:

        - `notification_type`: The type of the notification.
        - `training_uuid`: The UUID of the corresponding training.
        - `body`: The notification body.

        If an error occurs during the handling of the notification, the error will be logged
        and a 500 response is sent.
        """
        try:
            content_len = int(self.headers.get("Content-Length"))
            request_content = self.rfile.read(content_len)
            content = json.loads(request_content)
            self._logger.debug("notification received: " + json.dumps(content))
            self.handle_message(
                str(content["notification_type"]),
                UUID(content["training_uuid"]),
                content["body"],
            )
        except Exception as e:
            self._logger.fatal(e)
            self.send_response(500)
        self.end_headers()

    def handle_message(self, notification_type: str, training_id: UUID, data: Any):
        """
        Forward the notification to the corresponding client communication module function.

        Args:
            notification_type (str): type of the notification
            training_id (UUID): UUID of the corresponding training
            data (Any): notification body
        """
        self._logger.info(f"receive notification '{notification_type}' for training '{training_id}'")
        match notification_type:
            case "TRAINING_START":
                Communication.init_training(
                    training_id=training_id,
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(200)
            case "UPDATE_ROUND_START":
                Communication.start_training(
                    training_id=training_id,
                    round=int(data["round"]),
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(202)
            case "MODEL_TEST_ROUND":
                Communication.start_testing(
                    training_id=training_id,
                    round=int(data["round"]),
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(202)
            case "TRAINING_FINISHED":
                Communication.end_training(
                    training_id=training_id,
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(200)
            case _:
                status_code = Communication.unknown_message(
                    mesage_type=data["notification_type"],
                    training_id=training_id,
                    data=data["data"],
                )
                self.send_response(status_code)

Classes

ClientServerHandler

class ClientServerHandler(
    request,
    client_address,
    server
)

Default client server handler.

The handler will receive notifications from the server and forward them to the client communication module.

View Source
class ClientServerHandler(BaseHTTPRequestHandler):
    """
    Default client server handler.

    The handler will receive notifications from the server and forward them to the client communication module.
    """

    _logger = logging.getLogger("fl.client")
    """Logger instance for the client server handler."""

    def do_POST(self):
        """
        Receive a notification from the server.

        All notifications are expected to be in JSON format and contain the following fields:

        - `notification_type`: The type of the notification.
        - `training_uuid`: The UUID of the corresponding training.
        - `body`: The notification body.

        If an error occurs during the handling of the notification, the error will be logged
        and a 500 response is sent.
        """
        try:
            content_len = int(self.headers.get("Content-Length"))
            request_content = self.rfile.read(content_len)
            content = json.loads(request_content)
            self._logger.debug("notification received: " + json.dumps(content))
            self.handle_message(
                str(content["notification_type"]),
                UUID(content["training_uuid"]),
                content["body"],
            )
        except Exception as e:
            self._logger.fatal(e)
            self.send_response(500)
        self.end_headers()

    def handle_message(self, notification_type: str, training_id: UUID, data: Any):
        """
        Forward the notification to the corresponding client communication module function.

        Args:
            notification_type (str): type of the notification
            training_id (UUID): UUID of the corresponding training
            data (Any): notification body
        """
        self._logger.info(f"receive notification '{notification_type}' for training '{training_id}'")
        match notification_type:
            case "TRAINING_START":
                Communication.init_training(
                    training_id=training_id,
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(200)
            case "UPDATE_ROUND_START":
                Communication.start_training(
                    training_id=training_id,
                    round=int(data["round"]),
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(202)
            case "MODEL_TEST_ROUND":
                Communication.start_testing(
                    training_id=training_id,
                    round=int(data["round"]),
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(202)
            case "TRAINING_FINISHED":
                Communication.end_training(
                    training_id=training_id,
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(200)
            case _:
                status_code = Communication.unknown_message(
                    mesage_type=data["notification_type"],
                    training_id=training_id,
                    data=data["data"],
                )
                self.send_response(status_code)

Ancestors (in MRO)

  • http.server.BaseHTTPRequestHandler
  • socketserver.StreamRequestHandler
  • socketserver.BaseRequestHandler

Class variables

MessageClass
default_request_version
disable_nagle_algorithm
error_content_type
error_message_format
monthname
protocol_version
rbufsize
responses
server_version
sys_version
timeout
wbufsize
weekdayname

Methods

address_string

def address_string(
    self
)

Return the client address.

View Source
    def address_string(self):
        """Return the first element of ``self.client_address``."""
        host = self.client_address[0]
        return host

date_time_string

def date_time_string(
    self,
    timestamp=None
)

Return the current date and time formatted for a message header.

View Source
    def date_time_string(self, timestamp=None):
        """Format a timestamp for use in a message header.

        Falls back to the current time when no ``timestamp`` is given; the
        result is produced by ``email.utils.formatdate`` with ``usegmt=True``.
        """
        moment = time.time() if timestamp is None else timestamp
        return email.utils.formatdate(moment, usegmt=True)

do_POST

def do_POST(
    self
)

Receive a notification from the server.

All notifications are expected to be in JSON format and contain the following fields:

  • notification_type: The type of the notification.
  • training_uuid: The UUID of the corresponding training.
  • body: The notification body.

If an error occurs during the handling of the notification, the error will be logged and a 500 response is sent.

View Source
    def do_POST(self):
        """
        Receive a notification from the server.

        All notifications are expected to be in JSON format and contain the following fields:

        - `notification_type`: The type of the notification.
        - `training_uuid`: The UUID of the corresponding training.
        - `body`: The notification body.

        If an error occurs during the handling of the notification (e.g. missing or invalid
        `Content-Length` header, malformed JSON, missing fields or an exception raised while
        forwarding the notification), the error will be logged together with its traceback
        and a 500 response is sent.
        """
        try:
            content_len = int(self.headers.get("Content-Length"))
            request_content = self.rfile.read(content_len)
            content = json.loads(request_content)
            # Pass the payload as a logging argument instead of concatenating strings.
            self._logger.debug("notification received: %s", json.dumps(content))
            self.handle_message(
                str(content["notification_type"]),
                UUID(content["training_uuid"]),
                content["body"],
            )
        except Exception as e:
            # Request boundary: keep the server alive, but record the full traceback
            # so the failing notification can be diagnosed afterwards.
            self._logger.critical(e, exc_info=True)
            self.send_response(500)
        self.end_headers()

end_headers

def end_headers(
    self
)

Send the blank line ending the MIME headers.

View Source
    def end_headers(self):
        """Append the blank line terminating the headers and flush the buffer.

        Nothing is done for ``HTTP/0.9`` requests.
        """
        if self.request_version == 'HTTP/0.9':
            return
        self._headers_buffer.append(b"\r\n")
        self.flush_headers()

finish

def finish(
    self
)
View Source
    def finish(self):
        """Flush the write stream (unless already closed) and close both streams.

        A socket error raised by the final flush is ignored.
        """
        out_stream = self.wfile
        if not out_stream.closed:
            try:
                out_stream.flush()
            except socket.error:
                # The peer may already be gone (e.g. ECONNABORTED); nothing
                # useful can be done about it at this point.
                pass
        out_stream.close()
        self.rfile.close()

flush_headers

def flush_headers(
    self
)
View Source
    def flush_headers(self):
        """Write all buffered header lines to ``self.wfile`` and reset the buffer.

        Does nothing when no header buffer has been created yet.
        """
        if not hasattr(self, '_headers_buffer'):
            return
        pending = b"".join(self._headers_buffer)
        self.wfile.write(pending)
        self._headers_buffer = []

handle

def handle(
    self
)

Handle multiple requests if necessary.

View Source
    def handle(self):
        """Process requests until ``self.close_connection`` is set.

        At least one request is always handled; further requests are handled
        as long as handling the previous one left ``close_connection`` false.
        """
        self.close_connection = True
        while True:
            self.handle_one_request()
            if self.close_connection:
                break

handle_expect_100

def handle_expect_100(
    self
)

Decide what to do with an "Expect: 100-continue" header.

If the client is expecting a 100 Continue response, we must respond with either a 100 Continue or a final response before waiting for the request body. The default is to always respond with a 100 Continue. You can behave differently (for example, reject unauthorized requests) by overriding this method.

This method should either return True (possibly after sending a 100 Continue response) or send an error response and return False.

View Source
    def handle_expect_100(self):
        """React to an "Expect: 100-continue" request header.

        This implementation unconditionally answers with a bare
        ``100 Continue`` status line, terminates the headers and reports
        success.

        Overriding implementations should either return True (possibly after
        sending a 100 Continue response) or send an error response and
        return False.
        """
        interim_status = HTTPStatus.CONTINUE
        self.send_response_only(interim_status)
        self.end_headers()
        return True

handle_message

def handle_message(
    self,
    notification_type: str,
    training_id: uuid.UUID,
    data: Any
)

Forward the notification to the corresponding client communication module function.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| notification_type | str | type of the notification | None |
| training_id | UUID | UUID of the corresponding training | None |
| data | Any | notification body | None |
View Source
    def handle_message(self, notification_type: str, training_id: UUID, data: Any):
        """
        Forward the notification to the corresponding client communication module function.

        Args:
            notification_type (str): type of the notification
            training_id (UUID): UUID of the corresponding training
            data (Any): notification body
        """
        self._logger.info(f"receive notification '{notification_type}' for training '{training_id}'")
        match notification_type:
            case "TRAINING_START":
                Communication.init_training(
                    training_id=training_id,
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(200)
            case "UPDATE_ROUND_START":
                Communication.start_training(
                    training_id=training_id,
                    round=int(data["round"]),
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(202)
            case "MODEL_TEST_ROUND":
                Communication.start_testing(
                    training_id=training_id,
                    round=int(data["round"]),
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(202)
            case "TRAINING_FINISHED":
                Communication.end_training(
                    training_id=training_id,
                    model_id=UUID(data["global_model_uuid"]),
                )
                self.send_response(200)
            case _:
                status_code = Communication.unknown_message(
                    mesage_type=data["notification_type"],
                    training_id=training_id,
                    data=data["data"],
                )
                self.send_response(status_code)

handle_one_request

def handle_one_request(
    self
)

Handle a single HTTP request.

You normally don't need to override this method; see the class doc string for information on how to handle specific HTTP commands such as GET and POST.

View Source
    def handle_one_request(self):
        """Handle a single HTTP request.

        Reads the request line, parses the request and dispatches to the
        ``do_<COMMAND>`` method of this handler. A ``TimeoutError`` raised
        anywhere in that sequence is logged and marks the connection to be
        closed.

        You normally don't need to override this method; see the class
        __doc__ string for information on how to handle specific HTTP
        commands such as GET and POST.

        """
        try:
            # Read at most one byte more than the accepted maximum so that an
            # over-long request line can be detected.
            self.raw_requestline = self.rfile.readline(65537)
            if len(self.raw_requestline) > 65536:
                # Reset the request attributes before the error reply is sent.
                self.requestline = ''
                self.request_version = ''
                self.command = ''
                self.send_error(HTTPStatus.REQUEST_URI_TOO_LONG)
                return
            if not self.raw_requestline:
                # Nothing was read: mark the connection to be closed.
                self.close_connection = True
                return
            if not self.parse_request():
                # An error code has been sent, just exit
                return
            # Dispatch by naming convention, e.g. "POST" -> self.do_POST().
            mname = 'do_' + self.command
            if not hasattr(self, mname):
                self.send_error(
                    HTTPStatus.NOT_IMPLEMENTED,
                    "Unsupported method (%r)" % self.command)
                return
            method = getattr(self, mname)
            method()
            self.wfile.flush() #actually send the response if not already done.
        except TimeoutError as e:
            #a read or a write timed out.  Discard this connection
            self.log_error("Request timed out: %r", e)
            self.close_connection = True
            return

log_date_time_string

def log_date_time_string(
    self
)

Return the current time formatted for logging.

View Source
    def log_date_time_string(self):
        """Return the current local time as ``DD/Mon/YYYY HH:MM:SS``.

        The month abbreviation is looked up in ``self.monthname``.
        """
        local = time.localtime(time.time())
        month_label = self.monthname[local.tm_mon]
        return "%02d/%3s/%04d %02d:%02d:%02d" % (
            local.tm_mday, month_label, local.tm_year,
            local.tm_hour, local.tm_min, local.tm_sec)

log_error

def log_error(
    self,
    format,
    *args
)

Log an error.

This is called when a request cannot be fulfilled. By default it passes the message on to log_message().

Arguments are the same as for log_message().

XXX This should go to the separate error log.

View Source
    def log_error(self, format, *args):
        """Log an error for a request that cannot be fulfilled.

        This implementation simply delegates to log_message(); the arguments
        have the same meaning as for that method.

        XXX This should go to the separate error log.

        """
        return_value = self.log_message(format, *args)
        del return_value  # the result of log_message() is intentionally discarded

log_message

def log_message(
    self,
    format,
    *args
)

Log an arbitrary message.

This is used by all other logging functions. Override it if you have specific logging wishes.

The first argument, FORMAT, is a format string for the message to be logged. If the format string contains any % escapes requiring parameters, they should be specified as subsequent arguments (it's just like printf!).

The client ip and current date/time are prefixed to every message.

Unicode control characters are replaced with escaped hex before writing the output to stderr.

View Source
    def log_message(self, format, *args):
        """Write a formatted log line to stderr.

        All other logging functions end up here; override this method for
        custom logging.

        ``format`` is a %-style format string and ``args`` are the values
        substituted into it (just like printf).

        Each line is prefixed with the result of address_string() and
        log_date_time_string(), and the message is passed through
        ``str.translate`` with ``self._control_char_table`` before it is
        written.

        """
        rendered = format % args
        line = "%s - - [%s] %s\n" % (
            self.address_string(),
            self.log_date_time_string(),
            rendered.translate(self._control_char_table),
        )
        sys.stderr.write(line)

log_request

def log_request(
    self,
    code='-',
    size='-'
)

Log an accepted request.

This is called by send_response().

View Source
    def log_request(self, code='-', size='-'):
        """Log an accepted request via log_message().

        An ``HTTPStatus`` member passed as ``code`` is logged by its numeric
        value. This is called by send_response().

        """
        status = code.value if isinstance(code, HTTPStatus) else code
        self.log_message('"%s" %s %s', self.requestline, str(status), str(size))

parse_request

def parse_request(
    self
)

Parse a request (internal).

The request should be stored in self.raw_requestline; the results are in self.command, self.path, self.request_version and self.headers.

Return True for success, False for failure; on failure, any relevant error response has already been sent back.

View Source
    def parse_request(self):
        """Parse a request (internal).

        The request should be stored in self.raw_requestline; the results
        are in self.command, self.path, self.request_version and
        self.headers. self.close_connection is updated according to the
        request version and the Connection header.

        Return True for success, False for failure; on failure, any relevant
        error response has already been sent back (an empty request line
        fails without a response being sent).

        """
        self.command = None  # set in case of error on the first line
        self.request_version = version = self.default_request_version
        self.close_connection = True
        requestline = str(self.raw_requestline, 'iso-8859-1')
        requestline = requestline.rstrip('\r\n')
        self.requestline = requestline
        # Whitespace-separated parts of the request line; the checks below
        # accept two (command, path) or three (command, path, version) parts.
        words = requestline.split()
        if len(words) == 0:
            return False

        if len(words) >= 3:  # Enough to determine protocol version
            version = words[-1]
            try:
                if not version.startswith('HTTP/'):
                    raise ValueError
                base_version_number = version.split('/', 1)[1]
                version_number = base_version_number.split(".")
                # RFC 2145 section 3.1 says there can be only one "." and
                #   - major and minor numbers MUST be treated as
                #      separate integers;
                #   - HTTP/2.4 is a lower version than HTTP/2.13, which in
                #      turn is lower than HTTP/12.3;
                #   - Leading zeros MUST be ignored by recipients.
                if len(version_number) != 2:
                    raise ValueError
                version_number = int(version_number[0]), int(version_number[1])
            except (ValueError, IndexError):
                self.send_error(
                    HTTPStatus.BAD_REQUEST,
                    "Bad request version (%r)" % version)
                return False
            # Keep the connection open only if both the request and this
            # handler's protocol_version are at least HTTP/1.1.
            if version_number >= (1, 1) and self.protocol_version >= "HTTP/1.1":
                self.close_connection = False
            if version_number >= (2, 0):
                self.send_error(
                    HTTPStatus.HTTP_VERSION_NOT_SUPPORTED,
                    "Invalid HTTP version (%s)" % base_version_number)
                return False
            self.request_version = version

        if not 2 <= len(words) <= 3:
            self.send_error(
                HTTPStatus.BAD_REQUEST,
                "Bad request syntax (%r)" % requestline)
            return False
        command, path = words[:2]
        if len(words) == 2:
            # No version token was given; only GET is accepted in this form
            # (reported as an HTTP/0.9 request in the error message).
            self.close_connection = True
            if command != 'GET':
                self.send_error(
                    HTTPStatus.BAD_REQUEST,
                    "Bad HTTP/0.9 request type (%r)" % command)
                return False
        self.command, self.path = command, path

        # gh-87389: The purpose of replacing '//' with '/' is to protect
        # against open redirect attacks possibly triggered if the path starts
        # with '//' because http clients treat //path as an absolute URI
        # without scheme (similar to http://path) rather than a path.
        if self.path.startswith('//'):
            self.path = '/' + self.path.lstrip('/')  # Reduce to a single /

        # Examine the headers and look for a Connection directive.
        try:
            self.headers = http.client.parse_headers(self.rfile,
                                                     _class=self.MessageClass)
        except http.client.LineTooLong as err:
            self.send_error(
                HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                "Line too long",
                str(err))
            return False
        except http.client.HTTPException as err:
            self.send_error(
                HTTPStatus.REQUEST_HEADER_FIELDS_TOO_LARGE,
                "Too many headers",
                str(err)
            )
            return False

        # An explicit Connection header overrides the version-based default
        # chosen above.
        conntype = self.headers.get('Connection', "")
        if conntype.lower() == 'close':
            self.close_connection = True
        elif (conntype.lower() == 'keep-alive' and
              self.protocol_version >= "HTTP/1.1"):
            self.close_connection = False
        # Examine the headers and look for an Expect directive
        expect = self.headers.get('Expect', "")
        if (expect.lower() == "100-continue" and
                self.protocol_version >= "HTTP/1.1" and
                self.request_version >= "HTTP/1.1"):
            if not self.handle_expect_100():
                return False
        return True

send_error

def send_error(
    self,
    code,
    message=None,
    explain=None
)

Send and log an error reply.

Arguments are:

  • code: an HTTP error code (3 digits).
  • message: a simple, optional one-line reason phrase, `*( HTAB / SP / VCHAR / %x80-FF )`; defaults to the short entry matching the response code.
  • explain: a detailed message; defaults to the long entry matching the response code.

This sends an error response (so it must be called before any output has been generated), logs the error, and finally sends a piece of HTML explaining the error to the user.

View Source
    def send_error(self, code, message=None, explain=None):
        """Send and log an error reply.

        Arguments:
        * code:    an HTTP error code (3 digits)
        * message: an optional one-line reason phrase,
                   *( HTAB / SP / VCHAR / %x80-FF ); defaults to the short
                   entry of ``self.responses`` for the code
        * explain: a detailed message; defaults to the long entry of
                   ``self.responses`` for the code

        The error is logged, the status line and a ``Connection: close``
        header are sent and, where the status code allows a body, an error
        page rendered from ``self.error_message_format`` follows. It must
        therefore be called before any other output has been generated.

        """

        try:
            default_reason, default_explanation = self.responses[code]
        except KeyError:
            default_reason = default_explanation = '???'
        if message is None:
            message = default_reason
        if explain is None:
            explain = default_explanation

        self.log_error("code %d, message %s", code, message)
        self.send_response(code, message)
        self.send_header('Connection', 'close')

        # Message body is omitted for cases described in:
        #  - RFC7230: 3.3. 1xx, 204(No Content), 304(Not Modified)
        #  - RFC7231: 6.3.6. 205(Reset Content)
        bodyless_codes = (
            HTTPStatus.NO_CONTENT,
            HTTPStatus.RESET_CONTENT,
            HTTPStatus.NOT_MODIFIED,
        )
        body = None
        if code >= 200 and code not in bodyless_codes:
            # HTML encode to prevent Cross Site Scripting attacks
            # (see bug #1100201)
            substitutions = {
                'code': code,
                'message': html.escape(message, quote=False),
                'explain': html.escape(explain, quote=False)
            }
            page = self.error_message_format % substitutions
            body = page.encode('UTF-8', 'replace')
            self.send_header("Content-Type", self.error_content_type)
            self.send_header('Content-Length', str(len(body)))
        self.end_headers()

        if self.command != 'HEAD' and body:
            self.wfile.write(body)

send_header

def send_header(
    self,
    keyword,
    value
)

Send a MIME header to the headers buffer.

View Source
    def send_header(self, keyword, value):
        """Queue a header line in the headers buffer.

        Nothing is buffered for ``HTTP/0.9`` requests. A ``Connection``
        header with the value ``close`` or ``keep-alive`` (case-insensitive)
        additionally updates ``self.close_connection``.
        """
        if self.request_version != 'HTTP/0.9':
            if not hasattr(self, '_headers_buffer'):
                self._headers_buffer = []
            header_line = "%s: %s\r\n" % (keyword, value)
            self._headers_buffer.append(header_line.encode('latin-1', 'strict'))

        if keyword.lower() != 'connection':
            return
        directive = value.lower()
        if directive == 'close':
            self.close_connection = True
        elif directive == 'keep-alive':
            self.close_connection = False

send_response

def send_response(
    self,
    code,
    message=None
)

Add the response header to the headers buffer and log the response code.

Also send two standard headers with the server software version and the current date.

View Source
    def send_response(self, code, message=None):
        """Log the response code and queue the status line.

        After the status line, a ``Server`` header (from version_string())
        and a ``Date`` header (from date_time_string()) are queued as well.

        """
        self.log_request(code)
        self.send_response_only(code, message)
        standard_headers = (
            ('Server', self.version_string),
            ('Date', self.date_time_string),
        )
        for header_name, produce_value in standard_headers:
            self.send_header(header_name, produce_value())

send_response_only

def send_response_only(
    self,
    code,
    message=None
)

Send the response header only.

View Source
    def send_response_only(self, code, message=None):
        """Queue only the status line in the headers buffer.

        When ``message`` is None, the short reason phrase for ``code`` from
        ``self.responses`` is used (empty if the code is unknown). Nothing is
        done for ``HTTP/0.9`` requests.
        """
        if self.request_version == 'HTTP/0.9':
            return
        if message is None:
            message = self.responses[code][0] if code in self.responses else ''
        if not hasattr(self, '_headers_buffer'):
            self._headers_buffer = []
        status_line = "%s %d %s\r\n" % (self.protocol_version, code, message)
        self._headers_buffer.append(status_line.encode('latin-1', 'strict'))

setup

def setup(
    self
)
View Source
    def setup(self):
        self.connection = self.request
        if self.timeout is not None:
            self.connection.settimeout(self.timeout)
        if self.disable_nagle_algorithm:
            self.connection.setsockopt(socket.IPPROTO_TCP,
                                       socket.TCP_NODELAY, True)
        self.rfile = self.connection.makefile('rb', self.rbufsize)
        if self.wbufsize == 0:
            self.wfile = _SocketWriter(self.connection)
        else:
            self.wfile = self.connection.makefile('wb', self.wbufsize)

version_string

def version_string(
    self
)

Return the server software version string.

View Source
    def version_string(self):
        """Return ``server_version`` and ``sys_version`` joined by a space."""
        parts = (self.server_version, self.sys_version)
        return ' '.join(parts)