Aggregator
==========

The :code:`Aggregator` task collects the model parameters from all nodes, averages them layer by layer, and stores the result as the global model parameters.

.. code:: python

    # my_aggregator.py

    import numpy as np

    from manta_light.task import Task
    from manta_light.utils import bytes_to_numpy, numpy_to_bytes


    class Aggregator(Task):
        def run(self):
            # Aggregate the models from all nodes
            models = bytes_to_numpy(self.world.results.select("model_params"))
            aggregated_model = self.aggregate_models(list(models.values()))

            # Update the global model parameters
            self.world.globals["global_model_params"] = numpy_to_bytes(aggregated_model)

        def aggregate_models(self, models: list):
            """
            Aggregates the given models by averaging their weights.

            Parameters
            ----------
            models : list
                A list of model parameter dictionaries.

            Returns
            -------
            dict
                A dictionary representing the aggregated model parameters.
            """
            aggregated_model = {}
            for layer in models[0]:
                aggregated_model[layer] = np.mean(
                    [model[layer] for model in models], axis=0
                )
            return aggregated_model


    # Entry point that runs the task
    def main():
        Aggregator().run()