Développer une tâche avec manta-light

Les tâches sont les éléments de base des algorithmes décentralisés dans Manta. Une Task définit les calculs spécifiques exécutés sur chaque nœud d’un Swarm. En utilisant l’API manta-light, les tâches peuvent interagir avec les données locales du nœud, accéder aux paramètres globaux et contribuer aux résultats dans le contexte partagé du Swarm.

Pour plus d’informations, reportez-vous à la documentation de l’API Task.

Tâche Tâche

World, Globals et Results

L’API manta-light fournit trois interfaces principales essentielles pour le développement des tâches : World, Globals et Results.

  • World fournit le contexte d’exécution pour les tâches. Il permet aux tâches d’interagir avec des variables globales partagées (Globals) et de stocker les sorties spécifiques aux tâches (Results). Le World sert de couche de communication entre les tâches, permettant de synchroniser et de partager les données sur différents nœuds. Pour plus de détails, reportez-vous à la documentation de l’API World.

  • Globals sont des variables globales partagées entre toutes les tâches et itérations au sein d’un Swarm. Ces variables peuvent inclure des paramètres tels que des hyperparamètres pour l’entraînement ou d’autres paramètres de configuration qui doivent être accessibles à tous les nœuds. Pour plus de détails, reportez-vous à la documentation de l’API Globals.

  • Results permet aux tâches de stocker des sorties accessibles dans les itérations futures. Par exemple, des métriques de modèle, des paramètres entraînés ou des résultats intermédiaires peuvent être enregistrés et utilisés par les tâches suivantes. Pour plus de détails, reportez-vous à la documentation de l’API Results.

API locale

L’API locale fournit aux tâches un accès aux données locales du nœud. Ces données peuvent inclure des ensembles de données pour l’entraînement, des fichiers de configuration ou d’autres ressources stockées localement sur chaque nœud. L’API locale abstrait les mécanismes de stockage de données, permettant aux tâches de se concentrer sur le traitement des données plutôt que sur la gestion de leur récupération.

Reportez-vous à la documentation de l’API Local pour obtenir des informations détaillées sur l’interaction avec les données locales.

Exemple concret : Swarm d’apprentissage fédéré

Les exemples suivants démontrent un ensemble de tâches formant un Swarm d’apprentissage fédéré simple. Le Swarm se compose de trois tâches principales : Aggregator, Scheduler et Worker. Ces tâches se coordonnent pour entraîner un modèle partagé sur plusieurs nœuds sans partager les données brutes, préservant ainsi la confidentialité.

Tâche Worker (worker_task.py)

La tâche Worker entraîne un modèle local sur les données du nœud, l’évalue et renvoie les résultats au contexte global.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from manta_light.task import Task
from manta_light.utils import bytes_to_numpy, numpy_to_bytes
from .model import MLP

class Worker(Task):
    def __init__(self):
        super().__init__()

        # Load MNIST dataset from local storage
        raw_data = self.local.get_binary_data("mnist_part.npz")
        self.data = np.load(raw_data)
        self.model = MLP()

    def run(self):
        # Get model weights and hyperparameters from globals
        weights = self.world.globals["global_model_params"]
        self.model.set_weights(bytes_to_numpy(weights))
        hyperparameters = self.world.globals["hyperparameters"]

        # Train the model and collect metrics
        metrics = self.train_model(hyperparameters)

        # Save the metrics and model weights to the results
        self.world.results.add("metrics", metrics)
        self.world.results.add(
            "model_params", numpy_to_bytes(self.model.get_weights())
        )

    def train_model(self, hyperparameters: dict):
        """
        Trains the model using the given hyperparameters.

        Parameters
        ----------
        hyperparameters : dict
            A dictionary containing training parameters like loss function,
            optimizer, and epochs.

        Returns
        -------
        dict
            A dictionary with training metrics including loss and validation
            accuracy.
        """
        X_train, y_train, X_test, y_test = (
            self.data["x_train"],
            self.data["y_train"],
            self.data["x_test"],
            self.data["y_test"],
        )

        # Define loss function and optimizer
        criterion = getattr(nn, hyperparameters["loss"])(
            **hyperparameters.get("loss_params", {})
        )
        optimizer = getattr(optim, hyperparameters["optimizer"])(
            self.model.parameters(), **hyperparameters.get("optimizer_params", {})
        )

        metrics = {"loss": [], "val_loss": [], "val_acc": []}

        # Training loop
        for epoch in range(hyperparameters["epochs"]):
            self.model.train()
            for i in range(0, X_train.shape[0], hyperparameters["batch_size"]):
                optimizer.zero_grad()
                output = self.model(
                    torch.tensor(X_train[i:i + hyperparameters["batch_size"]]).float()
                )
                loss = criterion(
                    output, torch.tensor(y_train[i:i + hyperparameters["batch_size"]])
                )
                loss.backward()
                optimizer.step()

            # Validation
            self.model.eval()
            with torch.no_grad():
                val_output = self.model(torch.tensor(X_test).float())
                val_loss = criterion(val_output, torch.tensor(y_test))
                val_acc = (
                  (val_output.argmax(1) == torch.tensor(y_test)).float().mean().item()
                )

            metrics["loss"].append(loss.item())
            metrics["val_loss"].append(val_loss.item())
            metrics["val_acc"].append(val_acc)

        return metrics

def main():
    Worker().run()

Tâche Aggregator (aggregator.py)

La tâche Aggregator collecte les paramètres du modèle provenant de tous les nœuds, les agrège et met à jour le modèle global.

import numpy as np
from manta_light.task import Task
from manta_light.utils import bytes_to_numpy, numpy_to_bytes

class Aggregator(Task):
    def run(self):
        # Aggregate the models from all nodes
        models = bytes_to_numpy(self.world.results.select("model_params"))
        aggregated_model = self.aggregate_models(list(models.values()))

        # Update the global model parameters
        self.world.globals["global_model_params"] = numpy_to_bytes(aggregated_model)

    def aggregate_models(self, models: list):
        """
        Aggregates the given models by averaging their weights.

        Parameters
        ----------
        models : list
            A list of model parameter dictionaries.

        Returns
        -------
        dict
            A dictionary representing the aggregated model parameters.
        """
        aggregated_model = {}
        for layer in models[0]:
            aggregated_model[layer] = np.mean(
                [model[layer] for model in models], axis=0
            )
        return aggregated_model

def main():
    Aggregator().run()

Tâche Scheduler (scheduler.py)

La tâche Scheduler sélectionne les nœuds pour la prochaine itération d’entraînement en fonction de leur précision de validation et gère la logique de planification du Swarm.

from manta_light.task import Task

class Scheduler(Task):
    def run(self):
        # Select nodes based on validation accuracy metrics
        metrics = self.world.results.select("metrics")
        selected_nodes = self.select_nodes(metrics)
        if selected_nodes:
            self.world.schedule_next_iter(
                node_ids=selected_nodes, task_to_schedule_alias="worker"
            )

    def select_nodes(self, metrics: dict):
        """
        Selects nodes based on their validation accuracy.

        Parameters
        ----------
        metrics : dict
            A dictionary containing metrics for each node.

        Returns
        -------
        list
            A list of node IDs that are selected for the next iteration.
        """
        selected_nodes = []
        val_acc_threshold = 0.95

        for node_id, metr in metrics.items():
            if metr["val_acc"][-1] < val_acc_threshold:
                selected_nodes.append(node_id)

        if not selected_nodes:
            self.world.stop_swarm()

        return selected_nodes

def main():
    Scheduler().run()

Cette section fournit un exemple concret de la façon dont les tâches interagissent avec les données locales et globales au sein du Swarm, formant une boucle complète d’apprentissage fédéré qui exploite la nature décentralisée de la plateforme.