Source code for golf_federated.client.process.config.trainer.base

# -*- coding: utf-8 -*-
# @Author             : GZH
# @Created Time       : 2022/11/10 17:55
# @Email              : guozh29@mail2.sysu.edu.cn
# @Last Modified By   : GZH
# @Last Modified Time : 2022/11/10 17:55

from abc import abstractmethod
from typing import List

from numpy import ndarray


class BaseTrainer(object):
    """
    Trainer object class, the class function supports the main operation of trainer on Client.

    Concrete trainers are expected to subclass this and override every
    abstract method below (training, prediction, weight exchange, data
    access and shutdown).

    Attributes:
        mode (str): Mode of Trainer, as passed to the constructor.
        trained_num (int): Counter initialized to 0 here; it is not updated
            anywhere in this class (presumably subclasses increment it after
            training — verify against the concrete trainers).
    """

    # NOTE(review): this class derives from `object`, not `abc.ABC`, so the
    # `@abstractmethod` decorators below are NOT enforced: BaseTrainer and
    # subclasses missing overrides can still be instantiated, and any
    # un-overridden method silently returns None. Switching the base to
    # `abc.ABC` would enforce the contract but makes instantiating an
    # incomplete subclass raise TypeError — confirm all subclasses override
    # every method before changing it.

    def __init__(
            self,
            mode: str
    ) -> None:
        """
        Initialize the Trainer object.

        Args:
            mode (str): Mode of Trainer. Now support 'direct' and 'docker'.
                The value is stored as-is; it is not validated here.
        """
        # Initialize object properties.
        self.mode = mode
        self.trained_num = 0

    @abstractmethod
    def train(self) -> None:
        """
        Abstract method for training.
        """
        pass

    @abstractmethod
    def predict(
            self,
            data: ndarray
    ) -> ndarray:
        """
        Abstract method for prediction.

        Args:
            data (numpy.ndarray): Data values for prediction.

        Returns:
            numpy.ndarray: Prediction result.
        """
        pass

    @abstractmethod
    def update_model(
            self,
            new_weight: List
    ) -> None:
        """
        Abstract method for model weight update.

        Args:
            new_weight (List): Model weight for update.
        """
        pass

    @abstractmethod
    def get_model(self) -> List:
        """
        Abstract method for model weight getting.

        Returns:
            List: Model weight.
        """
        pass

    @abstractmethod
    def get_train_data(self) -> ndarray:
        """
        Abstract method for data values getting.

        Returns:
            numpy.ndarray: Data values.
        """
        pass

    @abstractmethod
    def get_train_label(self) -> ndarray:
        """
        Abstract method for data labels getting.

        Returns:
            numpy.ndarray: Data labels.
        """
        pass

    @abstractmethod
    def stop_trainer(self) -> None:
        """
        Abstract method for trainer stopping.
        """
        pass