Aggregator
La tâche Aggregator collecte les paramètres du modèle de tous les nœuds, les agrège et met à jour les paramètres du modèle global.
# my_aggregator.py
import numpy as np
from manta_light.task import Task
from manta_light.utils import bytes_to_numpy, numpy_to_bytes
class Aggregator(Task):
    """Task that averages the model parameters reported by all nodes.

    The averaged parameters are serialized and stored as the global model
    parameters in ``self.world.globals``.
    """

    def run(self):
        """Aggregate the node models and publish the result as the global model.

        Reads the ``"model_params"`` results, averages them with
        :meth:`aggregate_models`, and stores the serialized average under
        ``self.world.globals["global_model_params"]``.
        """
        # Aggregate the models from all nodes.
        # NOTE(review): assumes ``bytes_to_numpy`` returns a mapping
        # (presumably keyed by node) whose values are per-layer parameter
        # dicts, as implied by the ``.values()`` call below -- confirm.
        models = bytes_to_numpy(self.world.results.select("model_params"))
        aggregated_model = self.aggregate_models(list(models.values()))
        # Update the global model parameters.
        self.world.globals["global_model_params"] = numpy_to_bytes(
            aggregated_model
        )

    def aggregate_models(self, models: list, weights=None):
        """
        Aggregates the given models by averaging their weights, layer by layer.

        Parameters
        ----------
        models : list
            A non-empty list of model parameter dictionaries. Layer names are
            taken from the first model. A model missing one of those layers
            raises ``KeyError``. Extra layers present only in other models
            are ignored.
        weights : sequence of float, optional
            One weight per model (for example, each node's number of training
            samples). When omitted, every model contributes equally, which is
            the plain arithmetic mean.

        Returns
        -------
        dict
            A dictionary mapping each layer name to its averaged parameters.

        Raises
        ------
        ValueError
            If ``models`` is empty, or if ``weights`` is given and its length
            differs from the number of models.
        ZeroDivisionError
            If ``weights`` is given and all weights are zero.
        """
        if not models:
            # Without this guard, ``models[0]`` fails with an opaque IndexError.
            raise ValueError("aggregate_models() requires at least one model")
        if weights is not None and len(weights) != len(models):
            raise ValueError(
                f"expected {len(models)} weights (one per model), "
                f"got {len(weights)}"
            )
        # With weights=None, np.average computes the same mean as np.mean.
        return {
            layer: np.average(
                [model[layer] for model in models], axis=0, weights=weights
            )
            for layer in models[0]
        }
# Entry point that allows the task to be executed.
def main():
    """Create an ``Aggregator`` task and run it."""
    task = Aggregator()
    task.run()