US-12619855-B2 - Model pool for multimodal distributed learning
Abstract
A method performed by a central server node is provided. The method includes: receiving local model weights and a corresponding key from a local client node; and updating a model pool having a plurality of central models and corresponding keys associated with each of the central models. Updating the model pool is based on the local model weights, and one or more of the key corresponding to the local client node and the keys collectively corresponding to each of the central models. Updating the model pool comprises updating at least two of the plurality of central models contained in the model pool.
Inventors
- Jean Paulo MARTINS
Assignees
- TELEFONAKTIEBOLAGET LM ERICSSON (PUBL)
Dates
- Publication Date
- 20260505
- Application Date
- 20200117
Claims (16)
- 1. A method performed by a central server node, the method comprising: receiving local model weights and a corresponding key from a local client node; and updating a model pool having a plurality of central models and corresponding keys associated with each of the central models, wherein updating the model pool is based on the local model weights, and one or more of the key corresponding to the local client node and the keys collectively corresponding to each of the central models, and wherein updating the model pool comprises updating at least two of the plurality of central models contained in the model pool, wherein the central server node is a base station in a 5G communications network; wherein the local client node is a mobile device communicatively coupled with the base station via radio signal through the 5G communications network; wherein updating the model pool further comprises computing a similarity score comparing the key corresponding to the local client node to the keys collectively corresponding to each of the plurality of central models, and wherein updating at least two of the plurality of central models is based on the similarity score, and wherein updating at least two of the plurality of central models comprises computing new model weights $\Theta_i^{new} = (\Theta_i^{old} + \bar{\Theta}_i)/2$ for each $i$ in a set of indexes corresponding to the at least two of the plurality of central models, where: $\Theta_i^{new}$ refers to the new model weights for the central model corresponding to the index $i$, $\Theta_i^{old}$ refers to old model weights for the central model corresponding to the index $i$, $\bar{\Theta}_i = (1 - w_i)\Theta_i^{old} + w_i \Theta$, $W = (w_1, \ldots, w_M)$ is the similarity score, such that $w_i$ corresponds to the similarity between the key corresponding to the local client node and the key corresponding to the i-th central model of the plurality of central models, $M$ is the number of the plurality of central models in the model pool, and $\Theta$ refers to the local model weights.
- 2. A method performed by a central server node, the method comprising: receiving a request for a model from a local client node, wherein the request includes a key corresponding to the local client node; in response to receiving the request, constructing a model from a model pool having a plurality of central models and corresponding keys associated with each of the central models; and sending the constructed model to the local client node, wherein constructing the model from the model pool is based on the key corresponding to the local client node and the keys collectively corresponding to each of the central models, and wherein constructing the model from the model pool comprises aggregating at least two of the plurality of central models contained in the model pool, wherein the central server node is a base station in a 5G communications network; wherein the local client node is a mobile device communicatively coupled with the base station via radio signal through the 5G communications network; wherein constructing the model from the model pool further comprises computing a similarity score comparing the key corresponding to the local client node to the keys collectively corresponding to each of the plurality of central models, and wherein aggregating at least two of the plurality of central models is based on the similarity score, and wherein aggregating at least two of the plurality of central models comprises computing aggregated model weights $\theta_{agg} = \frac{1}{M}\sum_{i=1}^{M} w_i \theta_i$, where: $\theta_{agg}$ refers to the aggregated model weights for the constructed model, $M$ is the number of the plurality of central models in the model pool, $\theta_i$ refers to model weights for the central model corresponding to the index $i$, and $W = (w_1, \ldots, w_M)$ is the similarity score, such that $w_i$ corresponds to the similarity between the key corresponding to the local client node and the key corresponding to the i-th central model of the plurality of central models.
- 3. A central server node, the central server node comprising processing circuitry and a memory containing instructions executable by the processing circuitry, whereby the processing circuitry is operable to: receive local model weights and a corresponding key from a local client node; and update a model pool having a plurality of central models and corresponding keys associated with each of the central models, wherein updating the model pool is based on the local model weights, and one or more of the key corresponding to the local client node and the keys collectively corresponding to each of the central models, wherein updating the model pool comprises updating at least two of the plurality of central models contained in the model pool, wherein the central server node is a base station in a 5G communications network; wherein the local client node is a mobile device communicatively coupled with the base station via radio signal through the 5G communications network; wherein updating the model pool further comprises computing a similarity score comparing the key corresponding to the local client node to the keys collectively corresponding to each of the plurality of central models, and wherein updating at least two of the plurality of central models is based on the similarity score, and wherein updating at least two of the plurality of central models comprises computing new model weights $\Theta_i^{new} = (\Theta_i^{old} + \bar{\Theta}_i)/2$ for each $i$ in a set of indexes corresponding to the at least two of the plurality of central models, where: $\Theta_i^{new}$ refers to the new model weights for the central model corresponding to the index $i$, $\Theta_i^{old}$ refers to old model weights for the central model corresponding to the index $i$, $\bar{\Theta}_i = (1 - w_i)\Theta_i^{old} + w_i \Theta$, $W = (w_1, \ldots, w_M)$ is the similarity score, such that $w_i$ corresponds to the similarity between the key corresponding to the local client node and the key corresponding to the i-th central model of the plurality of central models, $M$ is the number of the plurality of central models in the model pool, and $\Theta$ refers to the local model weights.
- 4. The central server node of claim 3, wherein the processor is further configured to select the at least two of the plurality of central models from the model pool based on the similarity score.
- 5. The central server node of claim 3, wherein updating at least two of the plurality of central models comprises, for each central model of the at least two of the plurality of central models, combining the local model weights with a current version of the central model such that the local model weights are given a greater weight for central models having a higher corresponding similarity score than for central models having a lower corresponding similarity score.
- 6. The central server node of claim 3, wherein the similarity score $W$ is normalized such that $\sum_i w_i = 1$.
- 7. The central server node of claim 3, wherein updating the model pool is performed as a result of receiving local model weights and a corresponding key, without needing to wait on additional local model weights and corresponding keys from additional local client nodes.
- 8. The central server node of claim 3, wherein the processor is further configured to update the model pool such that a distribution of the keys corresponding to the central models of the model pool is modified based on the corresponding key from the local client node.
- 9. A central server node, the central server node comprising processing circuitry and a memory containing instructions executable by the processing circuitry, whereby the processing circuitry is operable to: receive a request for a model from a local client node, wherein the request includes a key corresponding to the local client node; in response to receiving the request, construct a model from a model pool having a plurality of central models and corresponding keys associated with each of the central models; and send the constructed model to the local client node, wherein constructing the model from the model pool is based on the key corresponding to the local client node and the keys collectively corresponding to each of the central models, wherein constructing the model from the model pool comprises aggregating at least two of the plurality of central models contained in the model pool, wherein the central server node is a base station in a 5G communications network; wherein the local client node is a mobile device communicatively coupled with the base station via radio signal through the 5G communications network; wherein constructing the model from the model pool further comprises computing a similarity score comparing the key corresponding to the local client node to the keys collectively corresponding to each of the plurality of central models, and wherein aggregating at least two of the plurality of central models is based on the similarity score, and wherein aggregating at least two of the plurality of central models comprises computing aggregated model weights $\theta_{agg} = \frac{1}{M}\sum_{i=1}^{M} w_i \theta_i$, where: $\theta_{agg}$ refers to the aggregated model weights for the constructed model, $M$ is the number of the plurality of central models in the model pool, $\theta_i$ refers to model weights for the central model corresponding to the index $i$, and $W = (w_1, \ldots, w_M)$ is the similarity score, such that $w_i$ corresponds to the similarity between the key corresponding to the local client node and the key corresponding to the i-th central model of the plurality of central models.
- 10. The central server node of claim 9, wherein the at least two of the plurality of central models from the model pool comprise all of the plurality of central models.
- 11. The central server node of claim 9, wherein aggregating at least two of the plurality of central models is performed such that the central models of the at least two of the plurality of central models having a higher corresponding similarity score are weighted more than models having a lower corresponding similarity score.
- 12. The central server node of claim 9, wherein the similarity score $W$ is normalized such that $\sum_i w_i = 1$.
- 13. The central server node of claim 9, wherein one or more keys of the key corresponding to the local client node and the keys of the model pool corresponding collectively to each of the plurality of central models, each includes a data distribution part and a deployment part, wherein the data distribution part of each of the one or more keys includes information describing a data distribution corresponding to the respective local client node or central model, and wherein the deployment part of each of the one or more keys includes information describing a deployment environment corresponding to the respective local client node or central model.
- 14. The central server node of claim 13, wherein computing a similarity score is based on a similarity function $d(k_1, k_2)$ between a first key $k_1$ having data distribution part $k_1^D$ and deployment part $k_1^S$ and a second key $k_2$ having data distribution part $k_2^D$ and deployment part $k_2^S$, such that $d(k_1, k_2) = d_D(k_1^D, k_2^D) + d_S(k_1^S, k_2^S)$.
- 15. The central server node of claim 9, wherein the processor is further configured to update the model pool such that a distribution of the keys corresponding to the central models of the model pool is modified based on the corresponding key from the local client node.
- 16. The central server node according to claim 3, for enabling accurate distributed machine learning.
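The update rule of claims 1 and 3, the aggregation rule of claims 2 and 9, and the key distance of claim 14 can be sketched together in Python. This is a minimal illustration under stated assumptions, not the patented implementation: the `Key` layout, the Euclidean choices for $d_D$ and $d_S$, the softmax mapping from distance to a normalized similarity score (claims 6 and 12), and the top-k selection policy standing in for claim 4 are all hypothetical, as are the names `ModelPool`, `similarity_scores` and `key_distance`.

```python
import math
from dataclasses import dataclass


@dataclass
class Key:
    # Hypothetical key layout (cf. claim 13): a data-distribution part, e.g. a
    # label histogram, and a deployment part, e.g. numeric environment descriptors.
    data: list[float]
    deploy: list[float]


def _euclid(a: list[float], b: list[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def key_distance(k1: Key, k2: Key) -> float:
    # Claim 14: d(k1, k2) = d_D(k1^D, k2^D) + d_S(k1^S, k2^S).
    # Using Euclidean distance for both d_D and d_S is an assumption.
    return _euclid(k1.data, k2.data) + _euclid(k1.deploy, k2.deploy)


def similarity_scores(client_key: Key, pool_keys: list[Key]) -> list[float]:
    # Assumed mapping from distance to similarity: a softmax over negative
    # distances, giving W = (w_1, ..., w_M) with sum(W) == 1 (claims 6 and 12).
    dists = [key_distance(client_key, k) for k in pool_keys]
    d_min = min(dists)  # shift for numerical stability; the scores are unchanged
    exps = [math.exp(-(d - d_min)) for d in dists]
    total = sum(exps)
    return [e / total for e in exps]


@dataclass
class ModelPool:
    weights: list[list[float]]  # Theta_i: flat parameter vector of central model i
    keys: list[Key]             # key associated with central model i

    def update(self, local_weights: list[float], client_key: Key, top_k: int = 2) -> None:
        """Claims 1/3: fold one client's weights into the top_k most similar models."""
        w = similarity_scores(client_key, self.keys)
        # Hypothetical selection policy (cf. claim 4): the top_k highest scores.
        selected = sorted(range(len(w)), key=lambda i: w[i], reverse=True)[:top_k]
        for i in selected:
            old = self.weights[i]
            # bar_Theta_i = (1 - w_i) * Theta_i^old + w_i * Theta
            mixed = [(1 - w[i]) * o + w[i] * t for o, t in zip(old, local_weights)]
            # Theta_i^new = (Theta_i^old + bar_Theta_i) / 2
            self.weights[i] = [(o + m) / 2 for o, m in zip(old, mixed)]

    def construct(self, client_key: Key) -> list[float]:
        """Claims 2/9: theta_agg = (1/M) * sum_i w_i * theta_i over all M models."""
        w = similarity_scores(client_key, self.keys)
        m = len(self.weights)
        dim = len(self.weights[0])
        return [sum(w[i] * self.weights[i][j] for i in range(m)) / m for j in range(dim)]


# Hypothetical pool of three scalar models with one-dimensional key parts.
pool = ModelPool(
    weights=[[0.0], [0.0], [0.0]],
    keys=[Key([0.0], [0.0]), Key([1.0], [0.0]), Key([10.0], [0.0])],
)
client = Key([0.0], [0.0])
pool.update([1.0], client)  # only models 0 and 1 (the two most similar) move
print(pool.weights)
print(pool.construct(client))
```

Expanding the update gives $\Theta_i^{new} = \Theta_i^{old} + \tfrac{w_i}{2}(\Theta - \Theta_i^{old})$, so each selected central model moves toward the client's weights by half its similarity score, which is the behavior claim 5 describes. The `construct` method applies the claimed aggregation formula literally, including the $1/M$ factor; with normalized scores ($\sum_i w_i = 1$) the result is therefore $1/M$ times a similarity-weighted average of the central models.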
Description
CROSS REFERENCE TO RELATED APPLICATION(S)
This application is a 35 U.S.C. § 371 National Phase Entry Application from PCT/EP2020/051196, filed Jan. 17, 2020, designating the United States, the disclosure of which is incorporated herein by reference in its entirety.
TECHNICAL FIELD
Disclosed are embodiments related to distributed learning and, in particular, to a model pool for multimodal distributed learning.
BACKGROUND
Distributed learning, which involves decentralized data, is a machine learning technique that addresses data-related constraints (e.g., privacy, bandwidth) that prohibit or otherwise limit the transfer of local training data from local client nodes to a central node for centralized processing. In such scenarios, learning from data occurs locally, and updates are sent from the local client nodes to a central server node, so that different local client nodes can influence the global model without sacrificing the security or privacy of their local training data.
Consider, as an example, the problem of predicting a word from a prefix typed on a mobile device. Local devices equipped with a learning procedure can model a user's typing behavior to suggest a suffix that completes the partially typed word. When many users share a common language, the resulting models from each local device can be combined (e.g., averaged) to produce a global model that is still representative of the problem. Federated learning (a type of distributed learning) is a technique that applies to this type of problem: multiple local models are averaged to create an accurate global model.
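The averaging step mentioned above can be sketched as follows. The passage says only that local models are "averaged"; the optional weighting by local sample counts is the common FedAvg variant and is included here as an assumption. Models are represented as flat lists of floats, and the function name `average_models` is hypothetical.

```python
from typing import Optional


def average_models(local_models: list[list[float]],
                   num_samples: Optional[list[int]] = None) -> list[float]:
    """Combine local models into one global model by (optionally weighted) averaging."""
    if not local_models:
        raise ValueError("need at least one local model")
    if num_samples is None:
        num_samples = [1] * len(local_models)  # plain, unweighted average
    total = sum(num_samples)
    dim = len(local_models[0])
    # Each global parameter is the sample-weighted mean of that parameter across clients.
    return [
        sum(n * model[j] for n, model in zip(num_samples, local_models)) / total
        for j in range(dim)
    ]


# Two hypothetical clients, each with a two-parameter model.
print(average_models([[1.0, 2.0], [3.0, 4.0]]))          # [2.0, 3.0]
print(average_models([[1.0, 2.0], [3.0, 4.0]], [1, 3]))  # [2.5, 3.5]
```

When every client holds the same amount of data, the weighted and unweighted forms coincide.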
Federated learning works in rounds, where each round comprises:
- At a central server node, selecting a subset of participating nodes.
- At the central server node, sending the current global model to the selected nodes.
- At each of the selected nodes, training the global model locally on a local dataset.
- At each of the selected nodes, sending updates reflecting the local training back to the central server node.
- At the central server node, aggregating the updates and applying them to the global model.
Continuing with the previous example of word prediction, now consider multiple users typing in different languages. In this case, the training data distributions may vary considerably from one user to another. Applying the traditional federated learning approach here undoubtedly results in undesired accuracy degradation, because updates learned from one language might contradict updates from another. In this example, it may suffice to maintain a separate global model for each language. However, in more general real-world scenarios, such a prior separation may be hard or impossible to identify.
SUMMARY
Problems arise when training data are distributed differently among local client nodes. For example, typical federated learning techniques implement model composition through averaging, which is not robust to scenarios where local data distributions are far from independent and identically distributed (IID). Typical federated learning techniques also support only one central model, which is improved using updates from the local models; this poses challenges for non-overlapping, non-IID data distributions. Given either a target deployment scenario or a target data distribution, typical federated learning techniques can only provide a single central model which, if the local target distribution differs significantly from those of other local client nodes, is unlikely to be a good starting point.
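The round structure listed above, and the averaging problem it runs into with non-IID clients, can be reproduced with a deliberately tiny toy. This is a hedged sketch, not a real federated learning stack: each client's "model" is a single scalar, each client's local data is reduced to a hypothetical `target` value, local training is plain gradient descent on $(w - \text{target})^2$, and the names `local_train` and `federated_round` are invented for illustration.

```python
import random


def local_train(global_w: float, target: float, lr: float = 0.25, steps: int = 20) -> float:
    # Toy local objective (assumption): minimise (w - target)^2 by gradient descent,
    # starting from the global model received from the server.
    w = global_w
    for _ in range(steps):
        w -= lr * 2 * (w - target)
    return w


def federated_round(global_w: float, client_targets: list[float],
                    fraction: float, rng: random.Random) -> float:
    # 1. Select a subset of participating nodes.
    k = max(1, int(fraction * len(client_targets)))
    selected = rng.sample(client_targets, k)
    # 2-4. Send the global model, train locally, and collect the updated local models.
    local_models = [local_train(global_w, t) for t in selected]
    # 5. Aggregate by plain averaging and use the result as the new global model.
    return sum(local_models) / len(local_models)


rng = random.Random(0)
global_w = federated_round(0.0, [1.0, -1.0], fraction=1.0, rng=rng)
# Squared loss of the averaged model on each client's toy objective.
print(global_w, [(global_w - t) ** 2 for t in (1.0, -1.0)])
```

With two clients whose optima are +1 and -1 (a stand-in for two languages), the averaged global model lands near 0 and has a squared loss of roughly 1 on both clients, so it serves neither well; with identical targets, the same averaging recovers the shared optimum.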
Some federated learning techniques apply clustering to group local client nodes. Such techniques are limited and, among other problems, do not fully address the issues noted above. For example, such clustering approaches require the number of clusters into which the local client nodes are classified to be specified in advance, must be updated to reflect the current situation if the data distributions of local client nodes vary over time, and are limited in how they cluster local client nodes, relying on local training data but not on other local client resources. Further, such clustering techniques are all-or-nothing: if two local client nodes have similar data distributions but belong to different clusters, updates from one of the local client nodes will not be propagated to the other. Accordingly, improved learning techniques are needed.
Embodiments disclosed herein provide improved machine learning in a distributed learning environment, in situations where local training data differ significantly in their distribution. For example, such differences could arise from intrinsic local characteristics of the different local client nodes (e.g., hardware, software, geographical location, system type). Regardless of the reason for differing distributions, if such differences are signific