Développer une tâche avec manta-light

Les tâches sont les éléments de base des algorithmes décentralisés dansManta. Une Task définit les calculs spécifiques qui sont exécutéssur chaque nœud d’un Swarm. En utilisant l’API manta-light, les tâches peuvent interagir avec les données locales du nœud, accéder aux paramètres globaux et contribuer aux résultats dans le contexte partagé du Swarm.

Pour plus d’informations, reportez-vous à la Task Documentation de l’API.

Tâche Tâche

Monde, Globals et résultats

L’API manta-light fournit trois interfaces principales qui sontessentiel pour le développement des tâches : World, Globals et Results.

  • World fournit le contexte d’exécution pour les tâches. Il permet aux tâches d’interagir avec des variables globales partagées (Globals) et stocker des tâches spécifiques sorties (Results). Le World sert de moyen de communication couche entre les tâches, permettant de synchroniser et de partager les données sur différents nœuds. Pour plus de détails, reportez-vous à la World Documentation de l’API.

  • Globals sont des variables globalespartagées entre toutes les tâches et itérations au sein d’un Swarm. Ces variables peut inclure des paramètres tels que des hyperparamètres pour la formation ou autres paramètres de configuration qui doivent être accessibles à tous les nœuds. Pour en savoir plus pour plus de détails, reportez-vous à l’API Globals Documentation.

  • Les résultats permettent aux tâches de stocker sorties auxquelles on peut accéder dans les itérations futures. Par exemple, le modèle les métriques, les paramètres formés ou les résultats intermédiaires peuvent être enregistrés et utilisé par les tâches suivantes. Pour plus de détails, reportez-vous à la Results Documentation de l’API.

API locale

L’API locale fournit aux tâches un accès aux données locales du nœud.les données peuvent inclure des ensembles de données pour la formation, des fichiers de configuration ou d’autres ressources stockées localement sur chaque nœud. L’API locale résume les mécanismes de stockage des données, permettant aux tâches de se concentrer sur le traitement des données plutôt que de gérer la récupération des données.

Reportez-vous à la documentation de l’API Local pour des informations détaillées sur l’interaction avec les données locales.

Exemple concret : l’apprentissage fédéré à travers un Swarm

Les exemples suivants illustrent un ensemble de tâches qui forment unApprentissage fédéré avec un Swarm. Le Swarm se compose de trois tâches principales :Aggregator, Scheduler et Worker. Ces tâches coordonner pour former un modèle partagé sur plusieurs nœuds sans partage les données brutes, préservant la confidentialité.

Tâche de travail (worker_task.py)

La tâche Worker entraîne un modèle local sur les données du nœud, l’évalue et renvoie les résultats au contexte global.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from manta_light.task import Task
from manta_light.utils import bytes_to_numpy, numpy_to_bytes
from .model import MLP

class Worker(Task):
    def __init__(self):
        super().__init__()

        # Load MNIST dataset from local storage
        raw_data = self.local.get_binary_data("mnist_part.npz")
        self.data = np.load(raw_data)
        self.model = MLP()

    def run(self):
        # Get model weights and hyperparameters from globals
        weights = self.world.globals["global_model_params"]
        self.model.set_weights(bytes_to_numpy(weights))
        hyperparameters = self.world.globals["hyperparameters"]

        # Train the model and collect metrics
        metrics = self.train_model(hyperparameters)

        # Save the metrics and model weights to the results
        self.world.results.add("metrics", metrics)
        self.world.results.add(
            "model_params", numpy_to_bytes(self.model.get_weights())
        )

    def train_model(self, hyperparameters: dict):
        """
        Trains the model using the given hyperparameters.

        Parameters
        ----------
        hyperparameters : dict
            A dictionary containing training parameters like loss function,
            optimizer, and epochs.

        Returns
        -------
        dict
            A dictionary with training metrics including loss and validation
            accuracy.
        """
        X_train, y_train, X_test, y_test = (
            self.data["x_train"],
            self.data["y_train"],
            self.data["x_test"],
            self.data["y_test"],
        )

        # Define loss function and optimizer
        criterion = getattr(nn, hyperparameters["loss"])(
            **hyperparameters.get("loss_params", {})
        )
        optimizer = getattr(optim, hyperparameters["optimizer"])(
            self.model.parameters(), **hyperparameters.get("optimizer_params", {})
        )

        metrics = {"loss": [], "val_loss": [], "val_acc": []}

        # Training loop
        for epoch in range(hyperparameters["epochs"]):
            self.model.train()
            for i in range(0, X_train.shape[0], hyperparameters["batch_size"]):
                optimizer.zero_grad()
                output = self.model(
                    torch.tensor(X_train[i:i + hyperparameters["batch_size"]]).float()
                )
                loss = criterion(
                    output, torch.tensor(y_train[i:i + hyperparameters["batch_size"]])
                )
                loss.backward()
                optimizer.step()

            # Validation
            self.model.eval()
            with torch.no_grad():
                val_output = self.model(torch.tensor(X_test).float())
                val_loss = criterion(val_output, torch.tensor(y_test))
                val_acc = (
                  (val_output.argmax(1) == torch.tensor(y_test)).float().mean().item()
                )

            metrics["loss"].append(loss.item())
            metrics["val_loss"].append(val_loss.item())
            metrics["val_acc"].append(val_acc)

        return metrics

def main():
    Worker().run()

Tâche d’agrégation (aggregator.py)

La tâche Aggregator collecte les paramètres du modèle de tous les nœuds, les agrège et met à jour le modèle global.

import numpy as np
from manta_light.task import Task
from manta_light.utils import bytes_to_numpy, numpy_to_bytes

class Aggregator(Task):
    def run(self):
        # Aggregate the models from all nodes
        models = bytes_to_numpy(self.world.results.select("model_params"))
        aggregated_model = self.aggregate_models(list(models.values()))

        # Update the global model parameters
        self.world.globals["global_model_params"] = numpy_to_bytes(aggregated_model)

    def aggregate_models(self, models: list):
        """
        Aggregates the given models by averaging their weights.

        Parameters
        ----------
        models : list
            A list of model parameter dictionaries.

        Returns
        -------
        dict
            A dictionary representing the aggregated model parameters.
        """
        aggregated_model = {}
        for layer in models[0]:
            aggregated_model[layer] = np.mean(
                [model[layer] for model in models], axis=0
            )
        return aggregated_model

def main():
    Aggregator().run()

Tâche du planificateur (scheduler.py)

La tâche Scheduler sélectionne les nœuds pour la prochaine itération de formation en fonction de leur précision de validation et gère la planification d’un Swarm logique.

from manta_light.task import Task

class Scheduler(Task):
    def run(self):
        # Select nodes based on validation accuracy metrics
        metrics = self.world.results.select("metrics")
        selected_nodes = self.select_nodes(metrics)
        if selected_nodes:
            self.world.schedule_next_iter(
                node_ids=selected_nodes, task_to_schedule_alias="worker"
            )

    def select_nodes(self, metrics: dict):
        """
        Selects nodes based on their validation accuracy.

        Parameters
        ----------
        metrics : dict
            A dictionary containing metrics for each node.

        Returns
        -------
        list
            A list of node IDs that are selected for the next iteration.
        """
        selected_nodes = []
        val_acc_threshold = 0.95

        for node_id, metr in metrics.items():
            if metr["val_acc"][-1] < val_acc_threshold:
                selected_nodes.append(node_id)

        if not selected_nodes:
            self.world.stop_swarm()

        return selected_nodes

def main():
    Scheduler().run()

Cette section fournit un exemple concret de la manière dont les tâches interagissent avec les données locales et globales au sein du Swarm, formant un système fédéré complet boucle d’apprentissage qui exploite la nature décentralisée de la plateforme.