Scheduler

The Scheduler task selects the nodes whose latest validation accuracy is still below a threshold and schedules them for the next training iteration; once no such node remains, it stops the Swarm.

# my_scheduler.py
from manta_light.task import Task

class Scheduler(Task):
    """
    Task that picks the nodes to train again, or stops the swarm.

    On each run it reads the "metrics" results, keeps the nodes whose most
    recent validation accuracy is below the threshold, and schedules the
    "worker" task on them. When no node is kept, the swarm is stopped.
    """

    def run(self):
        """
        Schedules the "worker" task on every node that has not converged.

        Nothing is scheduled when the selection is empty; in that case
        ``select_nodes`` has already requested the swarm to stop.
        """
        # Fetch the stored "metrics" results
        # (presumably one entry per node ID; verify against the worker task)
        metrics = self.world.results.select("metrics")
        # Keep only the nodes whose validation accuracy is still too low
        selected_nodes = self.select_nodes(metrics)
        if selected_nodes:
            self.world.schedule_next_iter(
                node_ids=selected_nodes, task_to_schedule_alias="worker"
            )

    def select_nodes(
        self, metrics: dict, val_acc_threshold: float = 0.95
    ) -> list:
        """
        Selects nodes based on their validation accuracy.

        A node is selected when the last value of its ``"val_acc"`` history
        is strictly below ``val_acc_threshold``. If no node is selected, the
        swarm is stopped as a side effect.

        Parameters
        ----------
        metrics : dict
            A dictionary containing metrics for each node, keyed by node ID.
            Each value must provide a non-empty ``"val_acc"`` sequence; only
            its last element is inspected.
        val_acc_threshold : float, optional
            Validation accuracy from which a node is considered converged.
            Defaults to 0.95.

        Returns
        -------
        list
            A list of node IDs that are selected for the next iteration.
            Empty when every node has reached the threshold (or when
            ``metrics`` is empty).

        Raises
        ------
        KeyError
            If a node's metrics have no ``"val_acc"`` entry.
        IndexError
            If a node's ``"val_acc"`` history is empty.
        """
        # Build the list of nodes which have not yet converged on their own data
        selected_nodes = [
            node_id
            for node_id, node_metrics in metrics.items()
            if node_metrics["val_acc"][-1] < val_acc_threshold
        ]

        if not selected_nodes:  # Condition to stop the swarm
            self.world.stop_swarm()

        return selected_nodes

# Entry point that allows the task to be executed
def main():
    """Create a Scheduler task and run it."""
    scheduler = Scheduler()
    scheduler.run()