Swarm
Below is a code example that defines a federated learning swarm on the Manta platform. The FLSwarm
class sets up the worker, aggregator, and scheduler modules, seeds the shared global state, and wires
the modules into an iterative training loop.
```python
from pathlib import Path

from manta.module import Module
from manta.swarm import Swarm
from manta_light.utils import numpy_to_bytes
from modules.worker.model import MLP


class FLSwarm(Swarm):
    def __init__(self):
        super().__init__()

        # Define the Aggregator module, which combines updates from all workers
        self.aggregator = Module(
            Path("modules/aggregator.py"),
            "fl-pytorch-mnist:latest",  # Docker image used for the aggregator
            method="any",  # Execution method
            fixed=False,
            maximum=1,  # Only one aggregator is used
            alias="aggregator",
        )

        # Define the Worker module, which handles local model training
        self.worker = Module(
            Path("modules/worker"),
            "fl-pytorch-mnist:latest",  # Docker image used for workers
            alias="worker",
        )

        # Define the Scheduler module, which manages the swarm's iterations
        self.scheduler = Module(
            Path("modules/scheduler.py"),
            "fl-pytorch-mnist:latest",  # Docker image used for the scheduler
            maximum=1,  # Only one scheduler is used
            alias="scheduler",
        )

        # Set global hyperparameters shared by all tasks in the swarm
        self.set_global(
            "hyperparameters",
            {
                "epochs": 1,
                "batch_size": 32,
                "loss": "CrossEntropyLoss",
                "loss_params": {},
                "optimizer": "SGD",
                "optimizer_params": {"lr": 0.01, "momentum": 0.9},
            },
        )

        # Initialize the global model parameters,
        # converting them to bytes for transmission
        self.set_global("global_model_params", numpy_to_bytes(MLP().get_weights()))

    def execute(self):
        """
        Define the execution flow of the swarm:

        - Each iteration starts with the Worker.
        - The results are then sent to the Aggregator.
        - The Scheduler decides if the swarm should continue or stop based on convergence.

        +--------+     +------------+     +-----------+  if has_converged
        | Worker | --> | Aggregator | --> | Scheduler | ----------------> END PROGRAM
        +--------+     +------------+     +-----------+
            |                                   | else
            +--<<<----------<<<----------<<<----+
        """
        m = self.worker()         # Start with the worker task
        m = self.aggregator(m)    # Aggregate results from the workers
        return self.scheduler(m)  # Check for convergence or continue the loop
```
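The swarm imports MLP from modules/worker/model.py but does not show it. Below is a minimal sketch of what that file might contain: the class name and the get_weights() helper mirror the usage above, while the exact architecture (a small fully connected network for MNIST) and the helper's return type are assumptions, not Manta's reference implementation.

```python
# Hypothetical sketch of modules/worker/model.py -- the real module may differ.
# Only the MLP class name and get_weights() are implied by the swarm code above.
import numpy as np
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Small fully connected network for 28x28 MNIST images (assumed architecture)."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)

    def get_weights(self) -> list[np.ndarray]:
        # Export parameters as numpy arrays so they can be serialized
        # with numpy_to_bytes() and stored as global state.
        return [p.detach().cpu().numpy() for p in self.parameters()]
```

Returning plain numpy arrays keeps the model's parameters framework-agnostic at the serialization boundary, which is why the swarm can pass them through numpy_to_bytes() before sharing them globally.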
Warning

Do not forget to build and push your Docker image:

```bash
docker build -t fl-pytorch-mnist:latest .
docker image push fl-pytorch-mnist:latest
```
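If your nodes pull images from a shared registry rather than the local Docker daemon (a common setup, but an assumption here), tag the image with the registry prefix before pushing; registry.example.com below is a placeholder:

```bash
# Placeholder registry host -- replace with your own.
docker tag fl-pytorch-mnist:latest registry.example.com/fl-pytorch-mnist:latest
docker push registry.example.com/fl-pytorch-mnist:latest
```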