Worker

The Worker task trains a local model on the node’s data, evaluates it on the node’s test split, and writes the resulting metrics and model weights back to the global context.
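
In this example the task is packaged as a folder with the layout below (the file names match the path comments in the snippets that follow):

my_worker/
├── __init__.py
├── my_model.py
└── my_worker_task.py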

Model Definition

# my_worker/my_model.py
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A small fully connected classifier for flattened 28x28 MNIST images
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )
        # Keep the parameters on the same device the inputs are moved to in forward()
        self.to(device)

    def forward(self, x):
        x = x.to(device)
        x = x.view(x.size(0), -1)
        x = self.layers(x)
        return x

    def set_weights(self, weights):
        # Load a dict of numpy arrays (as produced by get_weights) into the model
        state_dict = {name: torch.tensor(w) for name, w in weights.items()}
        self.load_state_dict(state_dict)

    def get_weights(self):
        weights = {}
        for layer, w in self.state_dict().items():
            weights[layer] = w.detach().cpu().numpy()
        return weights
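
get_weights and set_weights exchange a plain dictionary of NumPy arrays, which is what the task below serializes with numpy_to_bytes and deserializes with bytes_to_numpy. A minimal round-trip sketch, run locally outside Manta and assuming the package is importable as my_worker:

from my_worker.my_model import MLP

source, target = MLP(), MLP()
weights = source.get_weights()   # dict: parameter name -> NumPy array
target.set_weights(weights)      # copies the arrays into target's parameters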

Task Definition

# my_worker/my_worker_task.py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from manta_light.task import Task
from manta_light.utils import bytes_to_numpy, numpy_to_bytes
from .my_model import MLP, device

class Worker(Task):
    def __init__(self):
        super().__init__()

        # Load MNIST dataset from local storage
        raw_data = self.local.get_binary_data("mnist_part.npz")
        self.data = np.load(raw_data)
        self.model = MLP()

    def run(self):
        # Get model weights and hyperparameters from globals
        weights = self.world.globals["global_model_params"]
        self.model.set_weights(bytes_to_numpy(weights))
        hyperparameters = self.world.globals["hyperparameters"]

        # Train the model and collect metrics
        metrics = self.train_model(hyperparameters)

        # Save the metrics and model weights to the results
        self.world.results.add("metrics", metrics)
        self.world.results.add(
            "model_params", numpy_to_bytes(self.model.get_weights())
        )

    def train_model(self, hyperparameters: dict):
        """
        Trains the model using the given hyperparameters.

        Parameters
        ----------
        hyperparameters : dict
            A dictionary with training settings: the loss function and optimizer
            (by name, with optional keyword arguments), batch size, and number
            of epochs.

        Returns
        -------
        dict
            A dictionary with training metrics including loss and validation
            accuracy.
        """
        X_train, y_train, X_test, y_test = (
            self.data["x_train"],
            self.data["y_train"],
            self.data["x_test"],
            self.data["y_test"],
        )

        # Define loss function and optimizer
        criterion = getattr(nn, hyperparameters["loss"])(
            **hyperparameters.get("loss_params", {})
        )
        optimizer = getattr(optim, hyperparameters["optimizer"])(
            self.model.parameters(), **hyperparameters.get("optimizer_params", {})
        )

        metrics = {"loss": [], "val_loss": [], "val_acc": []}

        # Training loop over mini-batches of the local training split
        batch_size = hyperparameters["batch_size"]
        for epoch in range(hyperparameters["epochs"]):
            self.model.train()
            for i in range(0, X_train.shape[0], batch_size):
                optimizer.zero_grad()
                inputs = torch.tensor(X_train[i:i + batch_size]).float()
                targets = torch.tensor(y_train[i:i + batch_size]).long().to(device)
                output = self.model(inputs)
                loss = criterion(output, targets)
                loss.backward()
                optimizer.step()

            # Validation on the local test split
            self.model.eval()
            with torch.no_grad():
                y_test_t = torch.tensor(y_test).long().to(device)
                val_output = self.model(torch.tensor(X_test).float())
                val_loss = criterion(val_output, y_test_t)
                val_acc = (val_output.argmax(1) == y_test_t).float().mean().item()

            # Record the last mini-batch's training loss alongside validation metrics
            metrics["loss"].append(loss.item())
            metrics["val_loss"].append(val_loss.item())
            metrics["val_acc"].append(val_acc)

        return metrics
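
train_model resolves the loss and optimizer by name from torch.nn and torch.optim, so the "hyperparameters" entry in the global context is expected to be a dict with the keys used above. The values below are illustrative placeholders, not part of the example:

hyperparameters = {
    "loss": "CrossEntropyLoss",        # looked up as getattr(nn, ...)
    "loss_params": {},                 # optional kwargs for the loss
    "optimizer": "SGD",                # looked up as getattr(optim, ...)
    "optimizer_params": {"lr": 0.01},  # optional kwargs for the optimizer
    "epochs": 5,
    "batch_size": 64,
}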

Allowing Manta to Execute the Task

# my_worker/__init__.py
from .my_worker_task import Worker

# Entry point that allows Manta to execute the task.
# It must be defined in "__init__.py" when the task is packaged as a folder.
def main():
    Worker().run()