Source code for golf_federated.server.process.config.task.base

# -*- coding: utf-8 -*-
# @Author             : GZH
# @Created Time       : 2022/11/3 14:45
# @Email              : guozh29@mail2.sysu.edu.cn
# @Last Modified By   : GZH
# @Last Modified Time : 2022/11/3 14:45

import json
import os
import time
import zipfile
from abc import abstractmethod
from queue import Queue
from typing import List

import numpy as np

from golf_federated.server.process.config.model.base import BaseModel
from golf_federated.server.process.strategy.aggregation.base import BaseFed
from golf_federated.server.process.strategy.evaluation.base import BaseEval
from golf_federated.server.process.strategy.selection.base import BaseSelect


class BaseTask(object):
    """
    Task object class, the class function supports the main operation of task on Server.
    """

    def __init__(
            self,
            task_name: str,
            maxround: int,
            synchronous: bool,
            aggregation: 'BaseFed',
            evaluation: 'BaseEval',
            model: 'BaseModel',
            select: 'BaseSelect',
            module_path: str,
            isdocker: bool = False,
            image_name: str = ''
    ) -> None:
        """
        Initialize the Task object.

        Args:
            task_name (str): Name of the task.
            maxround (int): Maximum number of aggregation rounds.
            synchronous (bool): Whether the task is synchronous.
            aggregation (golf_federated.server.process.strategy.aggregation.base.BaseFed): Aggregation strategy object.
            evaluation (golf_federated.server.process.strategy.evaluation.base.BaseEval): Evaluation strategy object.
            model (golf_federated.server.process.config.model.base.BaseModel): Model object.
            select (golf_federated.server.process.strategy.selection.base.BaseSelect): Select strategy object.
            module_path (str): File path to model module (or to the Docker image archive when ``isdocker``).
            isdocker (bool): Whether the task requires Docker. Default as False.
            image_name (str): Name of Docker image. Default as ''.
        """
        # Initialize object properties.
        self.task_stop = False
        self.task_name = task_name
        self.maxround = maxround
        self.synchronous = synchronous
        self.aggregation = aggregation
        self.model = model
        self.evaluation = evaluation
        self.select = select
        self.module_path = module_path
        self.isdocker = isdocker
        self.image_name = image_name
        self.client_list = []
        # Wall-clock timestamp (time.time()) at which the current round started.
        self.round_time = time.time()
        # Per-round durations in seconds, appended by run_evaluation().
        self.time_record = []
        # Communication cost of the current round; not incremented in this class —
        # presumably updated by the server as parameters are exchanged (confirm).
        self.round_cost = 0
        # Per-round communication costs, appended by run_evaluation().
        self.cost_record = []
        # Archive paths filled in by info_tozip() / weight_tofile().
        self.info_path = ''
        self.weight_path = ''

    def start(
            self,
            client_list: List,
    ) -> None:
        """
        Start Task.

        Args:
            client_list (list): List of clients for this task.
        """
        # Update client list.
        self.client_list = client_list
        # Reset per-round bookkeeping so the first round is timed from now.
        self.round_time = time.time()
        self.round_cost = 0
        self.task_stop = False

    @abstractmethod
    def start_aggregation(
            self,
            aggregation_parameter: Queue
    ) -> bool:
        """
        Judge whether the conditions for starting aggregation have been met.

        Subclasses must override this. NOTE: BaseTask does not use ``ABCMeta``,
        so ``@abstractmethod`` is not enforced at instantiation; the base
        implementation returns None, which run_aggregation() treats as False.

        Args:
            aggregation_parameter (queue.Queue): Queue for storing aggregated parameters.

        Returns:
            Bool: Whether to start aggregation.
        """
        pass

    def run_aggregation(
            self,
            aggregation_parameter: Queue
    ) -> bool:
        """
        Run global model aggregation if start_aggregation() allows it.

        Args:
            aggregation_parameter (queue.Queue): Queue for storing aggregated parameters.

        Returns:
            Bool: Whether aggregation is executed.
        """
        if not self.start_aggregation(aggregation_parameter=aggregation_parameter):
            # Conditions for starting aggregation have not been met.
            return False
        # Run global model aggregation, passing the evaluation history along.
        self.model.model_aggre(
            aggregation=self.aggregation,
            parameter=aggregation_parameter,
            record=self.evaluation.get_record()
        )
        return True

    def run_evaluation(self) -> bool:
        """
        Run global model evaluation and close the current round's bookkeeping.

        Appends the round duration to ``time_record`` and the round's
        communication cost to ``cost_record``, then resets both counters.

        Returns:
            Bool: True when the task should finish (evaluation target reached,
            convergence reached, or the aggregation version has reached
            ``maxround``); False otherwise.
        """
        # Call model object to perform evaluation.
        self.model.model_eval(evaluation=self.evaluation)
        # Take one timestamp so the closing round and the next round share the
        # same boundary (no unaccounted gap between two separate time() calls).
        now = time.time()
        self.time_record.append(now - self.round_time)
        self.round_time = now
        # Record communication cost and start the next round from zero.
        self.cost_record.append(self.round_cost)
        self.round_cost = 0
        # Multiple termination conditions (short-circuit in this order).
        return (
            self.evaluation.reach_target()
            or self.evaluation.reach_convergence()
            or self.aggregation.aggregation_version >= self.maxround
        )

    def select_clients(self) -> List:
        """
        Select clients.

        Returns:
            List: Selected clients, as returned by the select strategy.
        """
        return self.select.select()

    def info_tozip(self) -> None:
        """
        Save task info to ``temp/<task_name>_info.zip`` and set ``self.info_path``.

        The archive contains the file at ``module_path`` (stored as
        ``<image_name>.tar`` for Docker tasks, otherwise ``module.py``) and
        ``task_info.json`` with the aggregation field.
        """
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('temp', exist_ok=True)
        # Get aggregation field and save it to a Json file.
        agg_dict = {
            "aggregationField": self.aggregation.get_field()
        }
        # NOTE(review): this scratch name is shared by every task; concurrent
        # calls for different tasks would overwrite each other — confirm.
        json_path = "temp/task_info.json"
        with open(json_path, "w") as f:
            json.dump(agg_dict, f)
        # Task info zip.
        file_path = 'temp/' + self.task_name + '_info.zip'
        # The context manager closes (and finalizes) the archive even if a write fails.
        with zipfile.ZipFile(
                file_path,
                'w',
                compression=zipfile.ZIP_DEFLATED
        ) as info_zipper:
            # Write model module or Docker image into the zip.
            info_zipper.write(
                filename=self.module_path,
                arcname=self.image_name + '.tar' if self.isdocker else 'module.py'
            )
            # Write Json file into the zip.
            info_zipper.write(
                filename=json_path,
                arcname='task_info.json'
            )
        # Update corresponding file path.
        self.info_path = file_path

    def weight_tofile(self) -> None:
        """
        Save model weight to ``temp/<task_name>_weight.zip`` and set ``self.weight_path``.

        The weights are first written to ``weight.npy`` in the current working
        directory and then stored in the archive under the same name.
        """
        # The archive lives in temp/, which may not exist yet if info_tozip()
        # has not been called before this method.
        os.makedirs('temp', exist_ok=True)
        # Save model weight. NOTE(review): if get_weight() returns a ragged list
        # of per-layer arrays, recent NumPy versions reject it here — confirm type.
        np.save('weight.npy', self.model.get_weight())
        # Model weight zip.
        file_path = 'temp/' + self.task_name + '_weight.zip'
        with zipfile.ZipFile(
                file_path,
                'w',
                compression=zipfile.ZIP_DEFLATED
        ) as weight_zipper:
            # Write model weight into the zip.
            weight_zipper.write(
                filename='weight.npy',
                arcname='weight.npy'
            )
        # Update corresponding file path.
        self.weight_path = file_path