Aggregator

La tâche Aggregator collecte les paramètres du modèle de tous les nœuds, les agrège (moyenne couche par couche) et met à jour les paramètres du modèle global.

# my_aggregator.py
import numpy as np
from manta_light.task import Task
from manta_light.utils import bytes_to_numpy, numpy_to_bytes

class Aggregator(Task):
    """Task that averages the model parameters reported by the nodes.

    The result is stored in the world globals under
    ``"global_model_params"``.
    """

    def run(self):
        """Aggregate the nodes' model parameters into the global model.

        Reads the ``"model_params"`` results from ``self.world.results``,
        averages them layer by layer with :meth:`aggregate_models`, and
        stores the result (converted with ``numpy_to_bytes``) under
        ``"global_model_params"`` in ``self.world.globals``.
        """
        # NOTE(review): assumes bytes_to_numpy returns a mapping (presumably
        # node id -> parameter dict); only its values are used here.
        models = bytes_to_numpy(self.world.results.select("model_params"))
        aggregated_model = self.aggregate_models(list(models.values()))

        # Update the global model parameters
        self.world.globals["global_model_params"] = numpy_to_bytes(aggregated_model)

    def aggregate_models(self, models: list, weights=None):
        """
        Aggregate the given models by (optionally weighted) averaging.

        Parameters
        ----------
        models : list
            A list of model parameter dictionaries mapping layer names to
            array-like values. Every model must contain the layers of the
            first model, with matching shapes.
        weights : sequence of float, optional
            One weight per model (e.g. the number of local training samples,
            as in FedAvg); their sum must be non-zero. When ``None``
            (default), every model contributes equally, i.e. a plain
            arithmetic mean.

        Returns
        -------
        dict
            A dictionary mapping each layer name of the first model to the
            element-wise average of that layer across all models.

        Raises
        ------
        ValueError
            If ``models`` is empty, or if ``weights`` does not contain
            exactly one weight per model.
        """
        if not models:
            # Without this guard, models[0] fails with an opaque IndexError
            # (e.g. when no node reported any result).
            raise ValueError("aggregate_models() requires at least one model")
        if weights is not None and len(weights) != len(models):
            raise ValueError(
                f"expected {len(models)} weights, got {len(weights)}"
            )

        # The set of layers comes from the first model; a layer missing from
        # another model raises KeyError. With weights=None, np.average
        # reduces to the same computation as np.mean along axis 0.
        return {
            layer: np.average(
                [model[layer] for model in models], axis=0, weights=weights
            )
            for layer in models[0]
        }

# Entry point through which the task is executed.
def main():
    """Create an Aggregator task instance and run it."""
    task = Aggregator()
    task.run()