Developing a Task with manta-light ================================== Tasks are the core building blocks of the decentralized algorithms in Manta. A :code:`Task` defines the specific computations that are executed on each node in a Swarm. Using the :code:`manta-light` API, tasks can interact with the node’s local data, access global parameters, and contribute results back to the shared context of the Swarm. For more information, refer to the :class:`Task ` API Documentation. .. image:: ../_static/images/light-task.png :alt: Task :align: center :class: only-light .. image:: ../_static/images/dark-task.png :alt: Task :align: center :class: only-dark World, Globals, and Results --------------------------- The :code:`manta-light` API provides three main interfaces that are essential for task development: :class:`World `, :class:`Globals `, and :class:`Results `. - :class:`World ` provides the execution context for tasks. It allows tasks to interact with shared global variables (:class:`Globals `) and store task-specific outputs (:class:`Results `). The :class:`World ` serves as the communication layer between tasks, making it possible to synchronize and share data across different nodes. For more details, refer to the :class:`World ` API Documentation. - :class:`Globals ` are global variables shared across all tasks and iterations within the Swarm. These variables can include parameters such as hyperparameters for training or other configuration settings that need to be accessible to all nodes. For more details, refer to the :class:`Globals ` API Documentation. - :class:`Results ` allow tasks to store outputs that can be accessed in future iterations. For instance, model metrics, trained parameters, or intermediate results can be saved and used by subsequent tasks. For more details, refer to the :class:`Results ` API Documentation. Local API --------- The Local API provides tasks with access to the node's local data. This data can include datasets for training, configuration files, or other resources that are stored locally on each node. The Local API abstracts the data storage mechanisms, allowing tasks to focus on processing the data rather than managing data retrieval. Refer to the :class:`Local ` API Documentation for detailed information on interacting with local data. Concrete Example: Federated Learning Swarm ------------------------------------------ The following examples demonstrate a set of tasks that form a simple Federated Learning Swarm. The swarm consists of three main tasks: :code:`Aggregator`, :code:`Scheduler`, and :code:`Worker`. These tasks coordinate to train a shared model across multiple nodes without sharing the raw data, preserving privacy. Worker Task (:code:`worker_task.py`) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The :code:`Worker` task trains a local model on the node's data, evaluates it, and sends the results back to the global context. .. code-block:: python import numpy as np import torch import torch.nn as nn import torch.optim as optim from manta_light.task import Task from manta_light.utils import bytes_to_numpy, numpy_to_bytes from .model import MLP class Worker(Task): def __init__(self): super().__init__() # Load MNIST dataset from local storage raw_data = self.local.get_binary_data("mnist_part.npz") self.data = np.load(raw_data) self.model = MLP() def run(self): # Get model weights and hyperparameters from globals weights = self.world.globals["global_model_params"] self.model.set_weights(bytes_to_numpy(weights)) hyperparameters = self.world.globals["hyperparameters"] # Train the model and collect metrics metrics = self.train_model(hyperparameters) # Save the metrics and model weights to the results self.world.results.add("metrics", metrics) self.world.results.add( "model_params", numpy_to_bytes(self.model.get_weights()) ) def train_model(self, hyperparameters: dict): """ Trains the model using the given hyperparameters. Parameters ---------- hyperparameters : dict A dictionary containing training parameters like loss function, optimizer, and epochs. Returns ------- dict A dictionary with training metrics including loss and validation accuracy. """ X_train, y_train, X_test, y_test = ( self.data["x_train"], self.data["y_train"], self.data["x_test"], self.data["y_test"], ) # Define loss function and optimizer criterion = getattr(nn, hyperparameters["loss"])( **hyperparameters.get("loss_params", {}) ) optimizer = getattr(optim, hyperparameters["optimizer"])( self.model.parameters(), **hyperparameters.get("optimizer_params", {}) ) metrics = {"loss": [], "val_loss": [], "val_acc": []} # Training loop for epoch in range(hyperparameters["epochs"]): self.model.train() for i in range(0, X_train.shape[0], hyperparameters["batch_size"]): optimizer.zero_grad() output = self.model( torch.tensor(X_train[i:i + hyperparameters["batch_size"]]).float() ) loss = criterion( output, torch.tensor(y_train[i:i + hyperparameters["batch_size"]]) ) loss.backward() optimizer.step() # Validation self.model.eval() with torch.no_grad(): val_output = self.model(torch.tensor(X_test).float()) val_loss = criterion(val_output, torch.tensor(y_test)) val_acc = ( (val_output.argmax(1) == torch.tensor(y_test)).float().mean().item() ) metrics["loss"].append(loss.item()) metrics["val_loss"].append(val_loss.item()) metrics["val_acc"].append(val_acc) return metrics def main(): Worker().run() Aggregator Task (:code:`aggregator.py`) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The :code:`Aggregator` task collects model parameters from all nodes, aggregates them, and updates the global model. .. code-block:: python import numpy as np from manta_light.task import Task from manta_light.utils import bytes_to_numpy, numpy_to_bytes class Aggregator(Task): def run(self): # Aggregate the models from all nodes models = bytes_to_numpy(self.world.results.select("model_params")) aggregated_model = self.aggregate_models(list(models.values())) # Update the global model parameters self.world.globals["global_model_params"] = numpy_to_bytes(aggregated_model) def aggregate_models(self, models: list): """ Aggregates the given models by averaging their weights. Parameters ---------- models : list A list of model parameter dictionaries. Returns ------- dict A dictionary representing the aggregated model parameters. """ aggregated_model = {} for layer in models[0]: aggregated_model[layer] = np.mean( [model[layer] for model in models], axis=0 ) return aggregated_model def main(): Aggregator().run() Scheduler Task (:code:`scheduler.py`) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The :code:`Scheduler` task selects nodes for the next training iteration based on their validation accuracy and manages the Swarm’s scheduling logic. .. code-block:: python from manta_light.task import Task class Scheduler(Task): def run(self): # Select nodes based on validation accuracy metrics metrics = self.world.results.select("metrics") selected_nodes = self.select_nodes(metrics) if selected_nodes: self.world.schedule_next_iter( node_ids=selected_nodes, task_to_schedule_alias="worker" ) def select_nodes(self, metrics: dict): """ Selects nodes based on their validation accuracy. Parameters ---------- metrics : dict A dictionary containing metrics for each node. Returns ------- list A list of node IDs that are selected for the next iteration. """ selected_nodes = [] val_acc_threshold = 0.95 for node_id, metr in metrics.items(): if metr["val_acc"][-1] < val_acc_threshold: selected_nodes.append(node_id) if not selected_nodes: self.world.stop_swarm() return selected_nodes def main(): Scheduler().run() This section provides a concrete example of how tasks interact with the local and global data within the Swarm, forming a complete federated learning loop that leverages the decentralized nature of the platform.