Source code for golf_federated.server.communication.api.download

# -*- coding: utf-8 -*-
# @Author             : GZH
# @Created Time       : 2022/11/3 13:14
# @Email              : guozh29@mail2.sysu.edu.cn
# @Last Modified By   : GZH
# @Last Modified Time : 2022/11/3 13:14

from flask import Response

from golf_federated.server.utils.cost import sim_cost
from golf_federated.utils.log import loggerhear

# TODO: The server and task names below are predefined placeholders; they are
# meant to be replaced once a database is introduced.

# Predefined server name.
server_name = 'server1'
# Predefined task name; the download handlers in this module pass it to get_task().
task_name = 'task1'


def download_model(serverhere: object) -> Response:
    """
    API handler that streams the global model weight file to the requester.

    Args:
        serverhere (golf_federated.server.process.config.device.base.MultiDeviceServer): Server object.

    Returns:
        Response: Streamed response carrying the model weight file.
    """

    # TODO: Identify the client that sent the request.
    request_notice = "Server %s is being request to download the global model." % serverhere.server_name
    loggerhear.log("Server Info ", request_notice)

    # Look up the Task object and the path to its global model weights.
    current_task = serverhere.get_task(task_name)
    weight_file = current_task.weight_path

    # Add the communication cost of this transfer to the task's round cost.
    current_task.round_cost += sim_cost(data=weight_file, file_path=True, communication_num=1)

    # Hand the chunk generator to Flask so the file is sent as a binary stream.
    payload_stream = send_chunk(weight_file)
    return Response(payload_stream, content_type='application/octet-stream')
def download_info(serverhere: object) -> Response:
    """
    API handler that streams the task info file to the requester.

    Args:
        serverhere (golf_federated.server.process.config.device.base.MultiDeviceServer): Server object.

    Returns:
        Response: Streamed response carrying the task info file.
    """

    # TODO: Identify the client that sent the request.
    request_notice = "Server %s is being request to download the task info." % serverhere.server_name
    loggerhear.log("Server Info ", request_notice)

    # Look up the Task object and the path to its info file.
    current_task = serverhere.get_task(task_name)
    info_file = current_task.info_path

    # Add the communication cost of this transfer to the task's round cost.
    current_task.round_cost += sim_cost(data=info_file, file_path=True, communication_num=1)

    # Hand the chunk generator to Flask so the file is sent as a binary stream.
    payload_stream = send_chunk(info_file)
    return Response(payload_stream, content_type='application/octet-stream')
def send_chunk(file_path: str, chunk_size: int = 2 * 1024 * 1024):
    """
    File stream transfer.

    Generator that reads a file in binary mode and yields it piece by piece.
    Because this is a generator, nothing happens until the first iteration:
    the argument check and the opening of the file both run at that point.
    The file is closed once the generator is exhausted or closed.

    Args:
        file_path (str): File path to transfer.
        chunk_size (int): Maximum number of bytes per yielded chunk.
            Defaults to 2 MiB (2 * 1024 * 1024).

    Yields:
        bytes: Consecutive, non-empty chunks of at most ``chunk_size`` bytes.
            An empty file yields nothing.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """

    # read(0) returns b'' immediately (silently empty stream) and a negative
    # size reads the whole file at once, so neither is a valid chunk size.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer, got %r" % (chunk_size,))

    with open(file_path, 'rb') as target_file:
        # iter(callable, sentinel) keeps calling read() until it returns the
        # empty bytes object, which is what read() produces at end of file.
        yield from iter(lambda: target_file.read(chunk_size), b'')