Source code for golf_federated.server.process.config.device.multidevice

# -*- coding: utf-8 -*-
# @Author             : GZH
# @Created Time       : 2022/11/3 11:01
# @Email              : guozh29@mail2.sysu.edu.cn
# @Last Modified By   : GZH
# @Last Modified Time : 2022/11/3 11:01
from typing import List

from golf_federated.server.communication.sse.schedule import publish_task_init, publish_update_model, publish_stop_train
from golf_federated.server.process.config.device.base import BaseServer
from golf_federated.server.process.config.task.base import BaseTask
from golf_federated.server.process.config.port.api import run_restful
from golf_federated.server.process.config.port.sse import init_sse
from golf_federated.utils.log import loggerhear


[docs]class MultiDeviceServer(BaseServer): """ Multi-Device Server object class, inheriting from Server class. """ def __init__( self, server_name: str, client_pool: List = [], api_host: str = '127.0.0.1', api_port: str = '7788', sse_host: str = '127.0.0.1', sse_port: str = '6379', sse_db: int = 6, ) -> None: """ Initialize the Multi-Device Server object. Args: server_name (str): Name of the Server object. client_pool (list): Client pool with client object or client name as element. Default as []. api_host (str): Host name to connect to the API host. Default as '127.0.0.1'. api_port (str): Port number to connect to the API host. Default as '7788'. sse_host (str): Host name to connect to the SSE host. Default as '127.0.0.1'. sse_port (str): Port number to connect to the SSE host. Default as '6379'. sse_db (int): Adopted Database. Default as 6. """ # Super class init. super().__init__(server_name=server_name, client_pool=client_pool) # Initialize object properties. self.api_host = api_host self.api_port = api_port self.sse_host = sse_host self.sse_port = sse_port self.sse_db = sse_db
[docs] def start_server(self) -> None: """ Start server. """ # Initialize SSE. config = init_sse( server=self, host=self.sse_host, port=self.sse_port, db=self.sse_db ) # Start restful API. run_restful( config=config, host=self.api_host, port=self.api_port, )
[docs] def start_task( self, task: BaseTask ) -> None: """ Start Task. Args: task (golf_federated.server.process.config.task.base.BaseTask): Corresponding Task object. """ # Update the client list of Select object. task.select.client_list = self.client_pool # Get selected clients. client_selected = task.select_clients() # Start Task. task.start(client_selected) loggerhear.log('Server Info ', "Task %s on Server %s starts." % (task.task_name, self.server_name)) # Record the started Task object. self.task_list.append(task) # Save task info. task.info_tozip() # Save initialized global model weight. task.weight_tofile() # Publish 'TaskInit' info. publish_task_init( host=self.sse_host, port=self.sse_port, db=self.sse_db )
[docs] def client_register( self, client_name: str ) -> None: """ Client register. Args: client_name (str): Name of the client to register. """ self.client_pool.append(client_name)
[docs] def task_update_model( self, task: BaseTask ) -> None: """ Update model. Args: task (golf_federated.server.process.config.task.base.BaseTask): Corresponding Task object. """ # Update the saved global model weight file. task.weight_tofile() # Publish 'UpdateModel' info. publish_update_model( host=self.sse_host, port=self.sse_port, db=self.sse_db )
[docs] def task_stop( self, task: BaseTask ) -> None: """ Stop task. Args: task (golf_federated.server.process.config.task.base.BaseTask): Corresponding Task object. """ # Publish 'StopTrain' info. publish_stop_train( host=self.sse_host, port=self.sse_port, db=self.sse_db ) task.task_stop = True