Worker¶
The Worker task trains a local model on the node's data, evaluates it, and sends the results back to the global context.
Model Definition¶
# my_worker/model.py
import torch
import torch.nn as nn

# Use the node's GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MLP(nn.Module):
    """Simple multilayer perceptron for MNIST classification."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )
        self.to(device)

    def forward(self, x):
        x = x.to(device)
        x = x.view(x.size(0), -1)  # flatten 28x28 images into 784-vectors
        return self.layers(x)

    def set_weights(self, weights):
        """Load a dict of numpy arrays into the model parameters."""
        for layer, w in weights.items():
            self.state_dict()[layer].copy_(torch.tensor(w))

    def get_weights(self):
        """Export the model parameters as a dict of numpy arrays."""
        return {
            layer: w.detach().cpu().numpy()
            for layer, w in self.state_dict().items()
        }
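The get_weights / set_weights pair defines the exchange format between the worker and the aggregator: a plain dictionary mapping state-dict keys to numpy arrays. A minimal sketch of the round trip, run locally outside Manta (assuming the my_worker package is importable):

# Illustrative only: round-trip the weight dictionary
import numpy as np

from my_worker.model import MLP

model = MLP()
weights = model.get_weights()  # e.g. {"layers.0.weight": np.ndarray, ...}
assert all(isinstance(w, np.ndarray) for w in weights.values())

model.set_weights(weights)  # restores the parameters from the same dict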
Task Definition¶
# my_worker/worker_task.py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from manta_light.task import Task
from manta_light.utils import bytes_to_numpy, numpy_to_bytes

from .model import MLP, device


class Worker(Task):
    def __init__(self):
        super().__init__()
        # Load the node's MNIST partition from local storage
        raw_data = self.local.get_binary_data("mnist_part.npz")
        self.data = np.load(raw_data)
        self.model = MLP()

    def run(self):
        # Get the model weights and hyperparameters from the globals
        weights = self.world.globals["global_model_params"]
        self.model.set_weights(bytes_to_numpy(weights))
        hyperparameters = self.world.globals["hyperparameters"]

        # Train the model and collect metrics
        metrics = self.train_model(hyperparameters)

        # Save the metrics and model weights to the results
        self.world.results.add("metrics", metrics)
        self.world.results.add(
            "model_params", numpy_to_bytes(self.model.get_weights())
        )

    def train_model(self, hyperparameters: dict):
        """
        Train the model using the given hyperparameters.

        Parameters
        ----------
        hyperparameters : dict
            A dictionary containing training parameters such as the loss
            function, the optimizer, and the number of epochs.

        Returns
        -------
        dict
            A dictionary with training metrics including loss and validation
            accuracy.
        """
        X_train, y_train, X_test, y_test = (
            self.data["x_train"],
            self.data["y_train"],
            self.data["x_test"],
            self.data["y_test"],
        )

        # Instantiate the loss function and optimizer from their class names,
        # e.g. "CrossEntropyLoss" and "Adam"
        criterion = getattr(nn, hyperparameters["loss"])(
            **hyperparameters.get("loss_params", {})
        )
        optimizer = getattr(optim, hyperparameters["optimizer"])(
            self.model.parameters(), **hyperparameters.get("optimizer_params", {})
        )

        metrics = {"loss": [], "val_loss": [], "val_acc": []}
        batch_size = hyperparameters["batch_size"]

        # Training loop
        for epoch in range(hyperparameters["epochs"]):
            self.model.train()
            for i in range(0, X_train.shape[0], batch_size):
                optimizer.zero_grad()
                output = self.model(
                    torch.tensor(X_train[i:i + batch_size]).float()
                )
                target = torch.tensor(y_train[i:i + batch_size]).long().to(device)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

            # Validation
            self.model.eval()
            with torch.no_grad():
                val_output = self.model(torch.tensor(X_test).float())
                val_target = torch.tensor(y_test).long().to(device)
                val_loss = criterion(val_output, val_target)
                val_acc = (
                    (val_output.argmax(1) == val_target).float().mean().item()
                )

            # Record the last batch loss of the epoch and the validation metrics
            metrics["loss"].append(loss.item())
            metrics["val_loss"].append(val_loss.item())
            metrics["val_acc"].append(val_acc)

        return metrics
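The hyperparameters entry in the globals is produced elsewhere (typically by the task that schedules the training round). The dictionary below is only an illustration of the keys that train_model reads; the values are examples, not prescribed by Manta:

# Illustrative only: actual values come from self.world.globals["hyperparameters"]
hyperparameters = {
    "loss": "CrossEntropyLoss",        # any loss class name from torch.nn
    "loss_params": {},                 # optional kwargs for the loss (may be omitted)
    "optimizer": "Adam",               # any optimizer class name from torch.optim
    "optimizer_params": {"lr": 1e-3},  # optional kwargs for the optimizer
    "epochs": 5,
    "batch_size": 64,
}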
Allowing Manta to Execute the Task¶
# my_worker/__init__.py
from .worker_task import Worker

# Entry point that allows Manta to execute the task.
# It must live in "__init__.py" when the task is packaged as a folder.
def main():
    Worker().run()
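With the three files above, the task folder looks like this:

my_worker/
├── __init__.py
├── model.py
└── worker_task.py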