Source code for fedscale.cloud.client_manager

import logging
import math
import pickle
from random import Random
from typing import Dict, List

from fedscale.cloud.internal.client_metadata import ClientMetadata


[docs]class ClientManager: def __init__(self, mode, args, sample_seed=233): self.client_metadata = {} self.client_on_hosts = {} self.mode = mode self.filter_less = args.filter_less self.filter_more = args.filter_more self.ucb_sampler = None if self.mode == 'oort': from thirdparty.oort.oort import create_training_selector self.ucb_sampler = create_training_selector(args=args) self.feasibleClients = [] self.rng = Random() self.rng.seed(sample_seed) self.count = 0 self.feasible_samples = 0 self.user_trace = None self.args = args if args.device_avail_file is not None: with open(args.device_avail_file, 'rb') as fin: self.user_trace = pickle.load(fin) self.user_trace_keys = list(self.user_trace.keys())
[docs] def register_client(self, host_id: int, client_id: int, size: int, speed: Dict[str, float], duration: float = 1) -> None: """Register client information to the client manager. Args: host_id (int): executor Id. client_id (int): client Id. size (int): number of samples on this client. speed (Dict[str, float]): device speed (e.g., compuutation and communication). duration (float): execution latency. """ uniqueId = self.getUniqueId(host_id, client_id) user_trace = None if self.user_trace is None else self.user_trace[self.user_trace_keys[int( client_id) % len(self.user_trace)]] self.client_metadata[uniqueId] = ClientMetadata(host_id, client_id, speed, user_trace) # remove clients if size >= self.filter_less and size <= self.filter_more: self.feasibleClients.append(client_id) self.feasible_samples += size if self.mode == "oort": feedbacks = {'reward': min(size, self.args.local_steps * self.args.batch_size), 'duration': duration, } self.ucb_sampler.register_client(client_id, feedbacks=feedbacks) else: del self.client_metadata[uniqueId]
[docs] def getAllClients(self): return self.feasibleClients
[docs] def getAllClientsLength(self): return len(self.feasibleClients)
[docs] def getClient(self, client_id): return self.client_metadata[self.getUniqueId(0, client_id)]
[docs] def registerDuration(self, client_id, batch_size, local_steps, upload_size, download_size): if self.mode == "oort" and self.getUniqueId(0, client_id) in self.client_metadata: exe_cost = self.client_metadata[self.getUniqueId(0, client_id)].get_completion_time( batch_size=batch_size, local_steps=local_steps, upload_size=upload_size, download_size=download_size ) self.ucb_sampler.update_duration( client_id, exe_cost['computation'] + exe_cost['communication'])
[docs] def get_completion_time(self, client_id, batch_size, local_steps, upload_size, download_size): return self.client_metadata[self.getUniqueId(0, client_id)].get_completion_time( batch_size=batch_size, local_steps=local_steps, upload_size=upload_size, download_size=download_size )
[docs] def registerSpeed(self, host_id, client_id, speed): uniqueId = self.getUniqueId(host_id, client_id) self.client_metadata[uniqueId].speed = speed
[docs] def registerScore(self, client_id, reward, auxi=1.0, time_stamp=0, duration=1., success=True): self.register_feedback(client_id, reward, auxi=auxi, time_stamp=time_stamp, duration=duration, success=success)
[docs] def register_feedback(self, client_id: int, reward: float, auxi: float = 1.0, time_stamp: float = 0, duration: float = 1., success: bool = True) -> None: """Collect client execution feedbacks of last round. Args: client_id (int): client Id. reward (float): execution utilities (processed feedbacks). auxi (float): unprocessed feedbacks. time_stamp (float): current wall clock time. duration (float): system execution duration. success (bool): whether this client runs successfully. """ # currently, we only use distance as reward if self.mode == "oort": feedbacks = { 'reward': reward, 'duration': duration, 'status': True, 'time_stamp': time_stamp } self.ucb_sampler.update_client_util(client_id, feedbacks=feedbacks)
[docs] def registerClientScore(self, client_id, reward): self.client_metadata[self.getUniqueId(0, client_id)].register_reward(reward)
[docs] def get_score(self, host_id, client_id): uniqueId = self.getUniqueId(host_id, client_id) return self.client_metadata[uniqueId].get_score()
[docs] def getClientsInfo(self): clientInfo = {} for i, client_id in enumerate(self.client_metadata.keys()): client = self.client_metadata[client_id] clientInfo[client.client_id] = client.distance return clientInfo
[docs] def next_client_id_to_run(self, host_id): init_id = host_id - 1 lenPossible = len(self.feasibleClients) while True: client_id = str(self.feasibleClients[init_id]) csize = self.client_metadata[client_id].size if csize >= self.filter_less and csize <= self.filter_more: return int(client_id) init_id = max( 0, min(int(math.floor(self.rng.random() * lenPossible)), lenPossible - 1))
[docs] def getUniqueId(self, host_id, client_id): return str(client_id)
# return (str(host_id) + '_' + str(client_id))
[docs] def clientSampler(self, client_id): return self.client_metadata[self.getUniqueId(0, client_id)].size
[docs] def clientOnHost(self, client_ids, host_id): self.client_on_hosts[host_id] = client_ids
[docs] def getCurrentclient_ids(self, host_id): return self.client_on_hosts[host_id]
[docs] def getClientLenOnHost(self, host_id): return len(self.client_on_hosts[host_id])
[docs] def getClientSize(self, client_id): return self.client_metadata[self.getUniqueId(0, client_id)].size
[docs] def getSampleRatio(self, client_id, host_id, even=False): totalSampleInTraining = 0. if not even: for key in self.client_on_hosts.keys(): for client in self.client_on_hosts[key]: uniqueId = self.getUniqueId(key, client) totalSampleInTraining += self.client_metadata[uniqueId].size # 1./len(self.client_on_hosts.keys()) return float(self.client_metadata[self.getUniqueId(host_id, client_id)].size) / float(totalSampleInTraining) else: for key in self.client_on_hosts.keys(): totalSampleInTraining += len(self.client_on_hosts[key]) return 1. / totalSampleInTraining
[docs] def getFeasibleClients(self, cur_time): if self.user_trace is None: clients_online = self.feasibleClients else: clients_online = [client_id for client_id in self.feasibleClients if self.client_metadata[self.getUniqueId( 0, client_id)].is_active(cur_time)] logging.info(f"Wall clock time: {round(cur_time)}, {len(clients_online)} clients online, " + f"{len(self.feasibleClients) - len(clients_online)} clients offline") return clients_online
[docs] def isClientActive(self, client_id, cur_time): return self.client_metadata[self.getUniqueId(0, client_id)].is_active(cur_time)
[docs] def select_participants(self, num_of_clients: int, cur_time: float = 0) -> List[int]: """Select participating clients for current execution task. Args: num_of_clients (int): number of participants to select. cur_time (float): current wall clock time. Returns: List[int]: indices of selected clients. """ self.count += 1 clients_online = self.getFeasibleClients(cur_time) if len(clients_online) <= num_of_clients: return clients_online pickled_clients = None clients_online_set = set(clients_online) if self.mode == "oort" and self.count > 1: pickled_clients = self.ucb_sampler.select_participant( num_of_clients, feasible_clients=clients_online_set) else: self.rng.shuffle(clients_online) client_len = min(num_of_clients, len(clients_online) - 1) pickled_clients = clients_online[:client_len] return pickled_clients
[docs] def resampleClients(self, num_of_clients, cur_time=0): return self.select_participants(num_of_clients, cur_time)
[docs] def getAllMetrics(self): if self.mode == "oort": return self.ucb_sampler.getAllMetrics() return {}
[docs] def getDataInfo(self): return {'total_feasible_clients': len(self.feasibleClients), 'total_num_samples': self.feasible_samples}
[docs] def getClientReward(self, client_id): return self.ucb_sampler.get_client_reward(client_id)
[docs] def get_median_reward(self): if self.mode == 'oort': return self.ucb_sampler.get_median_reward() return 0.