Scheduler
The Scheduler task selects nodes for the next training iteration based on their validation accuracy and manages the Swarm’s scheduling logic.
# my_scheduler.py
from manta_light.task import Task
class Scheduler(Task):
    """Task that picks the nodes to train again in the next iteration.

    Nodes whose latest validation accuracy is still below a threshold are
    scheduled for another ``worker`` iteration. When no node is left to
    schedule, the swarm is asked to stop.
    """

    def run(self):
        """Read the reported metrics and schedule the next iteration."""
        # Select all results from the worker
        metrics = self.world.results.select("metrics")

        # Select nodes based on validation accuracy metrics
        selected_nodes = self.select_nodes(metrics)

        # An empty selection means select_nodes() has already requested the
        # swarm to stop, so there is nothing left to schedule.
        if selected_nodes:
            self.world.schedule_next_iter(
                node_ids=selected_nodes, task_to_schedule_alias="worker"
            )

    def select_nodes(
        self, metrics: dict, val_acc_threshold: float = 0.95
    ) -> list:
        """
        Selects nodes based on their validation accuracy.

        A node is selected when the last entry of its ``"val_acc"`` history
        is strictly below ``val_acc_threshold``. A node with a missing or
        empty ``"val_acc"`` history is also selected, since it has not shown
        that it converged.

        If no node is selected, ``self.world.stop_swarm()`` is called before
        returning.

        Parameters
        ----------
        metrics : dict
            A dictionary containing metrics for each node, keyed by node ID.
            Each value is indexed with ``"val_acc"`` to get the node's
            validation accuracy history.
        val_acc_threshold : float, optional
            Validation accuracy a node must reach to be considered converged.
            Defaults to 0.95.

        Returns
        -------
        list
            A list of node IDs that are selected for the next iteration.
        """
        # Build a list of nodes which have not yet converged on their own data
        selected_nodes = [
            node_id
            for node_id, metr in metrics.items()
            if self._needs_more_training(metr, val_acc_threshold)
        ]

        if not selected_nodes:  # Condition to stop the swarm
            self.world.stop_swarm()
        return selected_nodes

    @staticmethod
    def _needs_more_training(metr, val_acc_threshold):
        """Return True when a node's latest validation accuracy is too low."""
        try:
            latest_val_acc = metr["val_acc"][-1]
        except (KeyError, IndexError):
            # No validation accuracy reported yet: keep training this node
            # rather than crashing the scheduler.
            return True
        return latest_val_acc < val_acc_threshold
# Entry point that allows the task to be executed
def main():
    """Create a Scheduler task and run it."""
    scheduler = Scheduler()
    scheduler.run()