# Source code for golf_federated.client.process.config.device.base

# -*- coding: utf-8 -*-
# @Author             : GZH
# @Created Time       : 2022/11/10 13:30
# @Email              : guozh29@mail2.sysu.edu.cn
# @Last Modified By   : GZH
# @Last Modified Time : 2022/11/10 13:30

from abc import abstractmethod
from typing import List

from numpy import ndarray

from golf_federated.utils.log import loggerhear
from golf_federated.utils.data import calculate_IW


class BaseClient(object):
    """
    Client object class, the class function supports the main operation of the client.

    Most operations delegate to ``self.trainer``, which subclasses are expected
    to set in ``init_trainer``. Until then it is ``None`` and the delegating
    methods will fail.

    NOTE(review): ``BaseClient`` derives from ``object`` (not ``abc.ABC``), so the
    ``@abstractmethod`` decorators below are not enforced at instantiation time.
    """

    def __init__(
            self,
            client_name: str,
    ) -> None:
        """
        Initialize the Client object.

        Args:
            client_name (str): Name of the Client object.
        """
        # Initialize object properties.
        self.client_name = client_name
        # Set later by ``init_trainer`` in subclasses.
        self.trainer = None
        # Names of the fields reported by ``get_field`` for model aggregation.
        self.field = []

    @abstractmethod
    def init_trainer(
            self,
            *args
    ) -> None:
        """
        Abstract method for initializing the Trainer.

        Args:
            *args: Variable number of parameters, see instantiation methods for details.
        """
        pass

    def train(self) -> None:
        """
        Perform local training.

        Logs the upcoming round number (``trainer.trained_num + 1``) and then
        delegates to ``trainer.train()``.
        """
        # Call the trainer to perform local training.
        loggerhear.log("Client Info ", "Training Round %d on %s!" % (self.trainer.trained_num + 1, self.client_name))
        self.trainer.train()

    def predict(self) -> ndarray:
        """
        Perform model prediction.

        Returns:
            Numpy.ndarray: Prediction results, as returned by ``trainer.predict()``.
        """
        # Call the trainer to perform model prediction.
        loggerhear.log("Client Info ", "Model prediction on %s!" % self.client_name)
        return self.trainer.predict()

    def get_model(self) -> List:
        """
        Get the current local model weight.

        Returns:
            List: Current local model weight, as returned by ``trainer.get_model()``.
        """
        # Call the trainer to get the current local model weight.
        return self.trainer.get_model()

    @abstractmethod
    def update_model(
            self,
            *args
    ) -> None:
        """
        Abstract method for updating local model weight.

        Args:
            *args: Variable number of parameters, see instantiation methods for details.
        """
        pass

    def get_field(self) -> dict:
        """
        Get data of specified fields for model aggregation.

        The field names listed in ``self.field`` are looked up. The recognized
        names are ``'clientRound'``, ``'informationRichness'`` and ``'dataSize'``.
        Any other name is silently ignored.

        Returns:
            Dict: Data of specified fields for model aggregation.
        """
        # Initialize dictionary.
        field = {}
        # Fill data to the dictionary.
        for f in self.field:
            if f == 'clientRound':
                # Number of local training rounds.
                field['clientRound'] = self.trainer.trained_num
            elif f == 'informationRichness':
                # Information richness of local data.
                field['informationRichness'] = calculate_IW(self.trainer.get_train_label())
            elif f == 'dataSize':
                # Size of local data: length of the first axis of the training
                # data (assumes an array-like with ``.shape`` — e.g. ndarray).
                field['dataSize'] = self.trainer.get_train_data().shape[0]
        return field

    def stop(self) -> None:
        """
        Stop local training.

        Trainers expose either ``stop()`` or ``stop_trainer()``. The available
        one is called. Errors raised by the trainer while stopping propagate to
        the caller instead of being swallowed.
        """
        # Call the trainer to stop local training.
        loggerhear.log("Client Info ", "Stop training on %s!" % self.client_name)
        # Select the method explicitly rather than using a bare ``except``. A bare
        # ``except`` also catches real failures inside ``stop()``, including
        # KeyboardInterrupt and SystemExit, and hides them behind a second call.
        stop_fn = getattr(self.trainer, 'stop', None)
        if callable(stop_fn):
            stop_fn()
        else:
            self.trainer.stop_trainer()