Swarm
Below is an example that defines a federated learning swarm on the Manta platform. The FLSwarm class shows how to set up the worker, aggregator, and scheduler modules.
from pathlib import Path

from manta.module import Module
from manta.swarm import Swarm
from manta_light.utils import numpy_to_bytes

from modules.worker.model import MLP


class FLSwarm(Swarm):
    def __init__(self):
        super().__init__()

        # Define the Aggregator module, which combines updates from all workers
        self.aggregator = Module(
            Path("modules/aggregator.py"),
            "fl-pytorch-mnist:latest",  # Docker image used for the aggregator
            method="any",  # Execution method
            fixed=False,
            maximum=1,  # Only one aggregator is used
            alias="aggregator",
        )

        # Define the Worker module, which handles local model training
        self.worker = Module(
            Path("modules/worker"),
            "fl-pytorch-mnist:latest",  # Docker image used for workers
            alias="worker",
        )

        # Define the Scheduler module, which manages the swarm's iterations
        self.scheduler = Module(
            Path("modules/scheduler.py"),
            "fl-pytorch-mnist:latest",  # Docker image used for the scheduler
            maximum=1,  # Only one scheduler is used
            alias="scheduler",
        )

        # Set global hyperparameters shared by all tasks in the swarm
        self.set_global(
            "hyperparameters",
            {
                "epochs": 1,
                "batch_size": 32,
                "loss": "CrossEntropyLoss",
                "loss_params": {},
                "optimizer": "SGD",
                "optimizer_params": {"lr": 0.01, "momentum": 0.9},
            },
        )

        # Initialize the global model parameters,
        # converting them to bytes for transmission
        self.set_global("global_model_params", numpy_to_bytes(MLP().get_weights()))

    def execute(self):
        """
        Define the execution flow of the swarm:
        - Each iteration starts with the Worker.
        - The results are then sent to the Aggregator.
        - The Scheduler decides if the swarm should continue or stop based on convergence.

        +--------+     +------------+     +-----------+  if has_converged
        | Worker | --> | Aggregator | --> | Scheduler | ----------------> END PROGRAM
        +--------+     +------------+     +-----------+
            |                                    |  else
            +--<<<----------<<<----------<<<----+
        """
        m = self.worker()         # Start with the worker task
        m = self.aggregator(m)    # Aggregate results from the workers
        return self.scheduler(m)  # Check for convergence or continue the loop
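
The swarm imports MLP from modules/worker/model.py and serializes its initial weights with numpy_to_bytes before sharing them as a global. The actual model is part of the example's worker module; the sketch below only illustrates what such a class could look like, assuming a small fully connected PyTorch network for MNIST and a get_weights helper that returns the parameters as NumPy arrays (the layer sizes and the helper's exact return format are assumptions, not the shipped implementation).

import torch.nn as nn


class MLP(nn.Module):
    """Hypothetical sketch of modules/worker/model.py."""

    def __init__(self):
        super().__init__()
        # Small fully connected network for 28x28 MNIST images (sizes assumed)
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.layers(x)

    def get_weights(self):
        # Return the parameters as NumPy arrays so that numpy_to_bytes can
        # serialize them for transmission (assumed return format)
        return [p.detach().cpu().numpy() for p in self.parameters()]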
Warning
Don't forget to build and push your Docker image:
docker build -t fl-pytorch-mnist:latest .
docker image push fl-pytorch-mnist:latest
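
The exact contents of the image depend on the example's code and dependencies; as a rough sketch, assuming a requirements.txt that pins PyTorch and the Manta client library, the Dockerfile could look something like this (the base image, file names, and copy paths are assumptions):

# Hypothetical Dockerfile sketch -- base image, requirements file, and
# copy paths are assumptions; adapt them to the example's actual layout.
FROM python:3.10-slim

WORKDIR /app

# Install the dependencies needed by the worker, aggregator, and scheduler
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the module code into the image
COPY modules/ ./modules/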