US-12621345-B2 - Data protection for machine learning models trained on client data
Abstract
Methods and systems are described herein for protecting client data while training machine learning models. The system may transmit, to client devices, simple models to be trained on a respective client device to generate predictions based on a respective subset of respective client data of the respective client device. The system may receive the trained simple models from the client devices. The system may input, into an ensemble model including the simple models, an unlabeled synthetic dataset. This may cause the ensemble model to aggregate a set of predictions generated by each simple model to generate labels for the unlabeled synthetic dataset. The system may then input, into a new model, the unlabeled synthetic dataset and the labels to train the new model to predict the labels for the unlabeled synthetic dataset.
Inventors
- Jeremy Goodsitt
- Michael Davis
- Taylor Turner
- Kenny Bean
- Tyler Farnan
Assignees
- CAPITAL ONE SERVICES, LLC
Dates
- Publication Date: 2026-05-05
- Application Date: 2024-09-09
Claims (20)
- 1 . A system for protecting client data from malicious actors while training machine learning models on aspects of the client data, the system comprising: one or more processors; and one or more non-transitory, computer-readable media having computer-executable instructions stored thereon that, when executed by the one or more processors, cause the system to perform operations comprising: transmitting, to each client device of a plurality of client devices, a set of untrained simple models, wherein each untrained simple model has been generated to be trained on a respective client device using a subset of a plurality of features of respective client data on the respective client device; receiving, from the plurality of client devices, a plurality of trained simple models, wherein each trained simple model is trained to generate predictions based on a respective subset of the plurality of features of the respective client data of the respective client device; retrieving a synthetic dataset that is unlabeled, the synthetic dataset generated based on the plurality of features of aggregate client data of the plurality of client devices; inputting, into each trained simple model of the plurality of trained simple models, a respective subset of the synthetic dataset according to the respective subset of the plurality of features on which the trained simple model is trained to cause each trained simple model to generate a set of predictions for the synthetic dataset based on the respective subset of the plurality of features; aggregating the set of predictions generated by each trained simple model of the plurality of trained simple models to generate a plurality of labels for the synthetic dataset; inputting, into a new model, the synthetic dataset and the plurality of labels to train the new model to predict the plurality of labels for the synthetic dataset; and transmitting a plurality of copies of the new model to the plurality of client devices to cause each copy of the new model to deploy on a client device of the plurality of client devices.
- 2 . A method for protecting client data from malicious actors while training machine learning models, the method comprising: transmitting, to a plurality of client devices, a plurality of simple models, wherein each simple model is generated to be trained on a respective client device to generate predictions based on a respective subset of respective client data of the respective client device; receiving, from the plurality of client devices, the plurality of simple models; inputting, into an ensemble model comprising the plurality of simple models, an unlabeled synthetic dataset to cause each simple model to generate a set of predictions for the unlabeled synthetic dataset based on a respective subset of the unlabeled synthetic dataset, wherein the respective subset of the unlabeled synthetic dataset corresponds to the respective subset of the respective client data on which the simple model is trained; causing the ensemble model to aggregate the set of predictions generated by each simple model to generate a plurality of labels for the unlabeled synthetic dataset; and inputting, into a new model, the unlabeled synthetic dataset and the plurality of labels to train the new model to predict the plurality of labels for the unlabeled synthetic dataset.
- 3 . The method of claim 2 , further comprising transmitting a copy of the new model to each client device of the plurality of client devices with a request to cause the copy of the new model to deploy on the respective client device.
- 4 . The method of claim 3 , wherein transmitting the copy of the new model to each client device comprises transmitting, to each client device, a command to perform testing of the copy using the respective client data of the respective client device and to return results of the testing.
- 5 . The method of claim 4 , further comprising receiving, from a first client device of the plurality of client devices, first results indicating that a first accuracy associated with a first copy of the new model deployed on the first client device does not satisfy an accuracy threshold.
- 6 . The method of claim 5 , further comprising, in response to receiving the first results indicating that the first accuracy does not satisfy the accuracy threshold: applying first weights to a first set of predictions generated by a first simple model associated with the first client device and applying second weights to other sets of predictions generated by other simple models of the plurality of simple models, wherein the first weights are higher than the second weights; aggregating the set of predictions generated by each simple model according to the first weights and the second weights to generate an updated plurality of labels for the unlabeled synthetic dataset; inputting, into the new model, the unlabeled synthetic dataset and the updated plurality of labels to update the new model to predict the updated plurality of labels for the unlabeled synthetic dataset; and transmitting an updated copy of the new model to the first client device to cause the updated copy of the new model to deploy on the first client device.
- 7 . The method of claim 5 , further comprising, in response to receiving the first results indicating that the first accuracy does not satisfy the accuracy threshold: applying first weights to a first set of predictions generated by a first simple model associated with the first client device and applying second weights to other sets of predictions generated by other simple models of the plurality of simple models, wherein the first weights are lower than the second weights; aggregating the set of predictions generated by each simple model according to the first weights and the second weights to generate an updated plurality of labels for the unlabeled synthetic dataset; inputting, into the new model, the unlabeled synthetic dataset and the updated plurality of labels to update the new model to predict the updated plurality of labels for the unlabeled synthetic dataset; and transmitting an updated copy of the new model to the first client device to cause the updated copy of the new model to deploy on the first client device.
- 8 . The method of claim 5 , further comprising, in response to receiving the first results indicating that the first accuracy does not satisfy the accuracy threshold: determining that first client data from the first client device comprises one or more features that are not included in the unlabeled synthetic dataset; generating an updated unlabeled synthetic dataset comprising the one or more features from the first client data; inputting, into the ensemble model, the updated unlabeled synthetic dataset to cause each simple model to generate an updated set of predictions for the updated unlabeled synthetic dataset based on an updated respective subset of the updated unlabeled synthetic dataset; aggregating the updated set of predictions generated by each simple model to generate an updated plurality of labels for the updated unlabeled synthetic dataset; inputting, into the new model, the updated unlabeled synthetic dataset and the updated plurality of labels to train the new model to predict the updated plurality of labels for the updated unlabeled synthetic dataset; and transmitting an updated copy of the new model to the first client device to cause the updated copy of the new model to deploy on the first client device.
- 9 . The method of claim 2 , further comprising: receiving, from the plurality of client devices, feature data indicating a plurality of features of aggregate client data of the plurality of client devices; determining, based on the feature data, one or more relationships between the plurality of features; separating the plurality of features into a first subset of features and a second subset of features based on the one or more relationships between the plurality of features such that related features are separated between the first subset of features and the second subset of features; and assigning, for each client device, the first subset of features to a first simple model and the second subset of features to a second simple model such that the first simple model is trained to generate a first set of predictions based on the first subset of features and the second simple model is trained to generate a second set of predictions based on the second subset of features.
- 10 . The method of claim 2 , further comprising: receiving, from the plurality of client devices, entry data indicating a plurality of entries of aggregate client data from the plurality of client devices; determining, based on the entry data, one or more relationships between the plurality of entries; separating the plurality of entries into a first subset of entries and a second subset of entries based on the one or more relationships between the plurality of entries such that related entries are separated between the first subset of entries and the second subset of entries; and assigning, for each client device, the first subset of entries to a first simple model and the second subset of entries to a second simple model such that the first simple model is trained to generate a first set of predictions based on the first subset of entries and the second simple model is trained to generate a second set of predictions based on the second subset of entries.
- 11 . The method of claim 2 , further comprising: receiving, from the plurality of client devices, indications of a first type of data and a second type of data of aggregate client data of the plurality of client devices; separating a plurality of features of the client data into a first subset of features having the first type of data and a second subset of features having the second type of data; and assigning, for each client device, the first subset of features to a first simple model and the second subset of features to a second simple model such that the first simple model is trained to generate a first set of predictions based on the first subset of features and the second simple model is trained to generate a second set of predictions based on the second subset of features.
- 12 . The method of claim 2 , further comprising transmitting, to each client device of the plurality of client devices, a set of simple models of the plurality of simple models to cause each simple model to train to generate the predictions based on the respective subset of respective client data of the respective client device.
- 13 . One or more non-transitory, computer-readable media storing instructions that, when executed by one or more processors, cause operations comprising: transmitting, to a plurality of client devices, a plurality of simple models, wherein each simple model is generated to be trained on a respective client device to generate predictions based on a respective subset of respective client data of the respective client device; receiving, from the plurality of client devices, the plurality of simple models; inputting, into an ensemble model comprising the plurality of simple models, an unlabeled synthetic dataset to cause the ensemble model to aggregate a set of predictions generated by each simple model to generate a plurality of labels for the unlabeled synthetic dataset; and inputting, into a new model, the unlabeled synthetic dataset and the plurality of labels to train the new model to predict the plurality of labels for the unlabeled synthetic dataset.
- 14 . The one or more non-transitory, computer-readable media of claim 13 , wherein inputting the unlabeled synthetic dataset into the ensemble model comprises causing each simple model to generate the set of predictions for the unlabeled synthetic dataset based on a respective subset of the unlabeled synthetic dataset, wherein the respective subset of the unlabeled synthetic dataset corresponds to the respective subset of the respective client data on which the simple model is trained.
- 15 . The one or more non-transitory, computer-readable media of claim 13 , wherein the instructions further cause the one or more processors to perform operations comprising transmitting a copy of the new model to each client device of the plurality of client devices with a request to cause the copy of the new model to deploy on the respective client device.
- 16 . The one or more non-transitory, computer-readable media of claim 15 , wherein transmitting the copy of the new model to each client device comprises transmitting, to each client device, a command to perform testing of the copy using the respective client data of the respective client device and to return results of the testing.
- 17 . The one or more non-transitory, computer-readable media of claim 16 , wherein the instructions further cause the one or more processors to perform operations comprising receiving, from a first client device of the plurality of client devices, first results indicating that a first accuracy associated with a first copy of the new model deployed on the first client device does not satisfy an accuracy threshold.
- 18 . The one or more non-transitory, computer-readable media of claim 17 , wherein the instructions further cause the one or more processors to perform operations comprising, in response to receiving the first results indicating that the first accuracy does not satisfy the accuracy threshold: applying first weights to a first set of predictions generated by a first simple model associated with the first client device and applying second weights to other sets of predictions generated by other simple models of the plurality of simple models, wherein the first weights are higher than the second weights; aggregating the set of predictions generated by each simple model according to the first weights and the second weights to generate an updated plurality of labels for the unlabeled synthetic dataset; inputting, into the new model, the unlabeled synthetic dataset and the updated plurality of labels to update the new model to predict the updated plurality of labels for the unlabeled synthetic dataset; and transmitting an updated copy of the new model to the first client device to cause the updated copy of the new model to deploy on the first client device.
- 19 . The one or more non-transitory, computer-readable media of claim 17 , wherein the instructions further cause the one or more processors to perform operations comprising, in response to receiving the first results indicating that the first accuracy does not satisfy the accuracy threshold: applying first weights to a first set of predictions generated by a first simple model associated with the first client device and applying second weights to other sets of predictions generated by other simple models of the plurality of simple models, wherein the first weights are lower than the second weights; aggregating the set of predictions generated by each simple model according to the first weights and the second weights to generate an updated plurality of labels for the unlabeled synthetic dataset; inputting, into the new model, the unlabeled synthetic dataset and the updated plurality of labels to update the new model to predict the updated plurality of labels for the unlabeled synthetic dataset; and transmitting an updated copy of the new model to the first client device to cause the updated copy of the new model to deploy on the first client device.
- 20 . The one or more non-transitory, computer-readable media of claim 17 , wherein the instructions further cause the one or more processors to perform operations comprising, in response to receiving the first results indicating that the first accuracy does not satisfy the accuracy threshold: determining that first client data from the first client device comprises one or more features that are not included in the unlabeled synthetic dataset; generating an updated unlabeled synthetic dataset comprising the one or more features from the first client data; inputting, into the ensemble model, the updated unlabeled synthetic dataset to cause each simple model to generate an updated set of predictions for the updated unlabeled synthetic dataset based on an updated respective subset of the updated unlabeled synthetic dataset; aggregating the updated set of predictions generated by each simple model to generate an updated plurality of labels for the updated unlabeled synthetic dataset; inputting, into the new model, the updated unlabeled synthetic dataset and the updated plurality of labels to train the new model to predict the updated plurality of labels for the updated unlabeled synthetic dataset; and transmitting an updated copy of the new model to the first client device to cause the updated copy of the new model to deploy on the first client device.
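Claims 5 to 7 (and 17 to 19) describe relabeling the synthetic dataset with per-model weights after one client reports that its copy of the new model misses an accuracy threshold. The sketch below illustrates only that weighted re-aggregation step, assuming each simple model's predictions are already available as one label per synthetic row. The names `weighted_labels`, `reweight_for_client`, `model_owner`, and `factor`, as well as the rule "accuracy below threshold means not satisfied", are hypothetical illustrations rather than anything specified in the claims. Retraining the new model on the updated labels and redeploying it would follow as described in claims 6 and 7.

```python
from collections import defaultdict


def weighted_labels(predictions_by_model, weights):
    """Weighted vote per synthetic row; models missing from `weights` default to 1.0."""
    num_rows = len(next(iter(predictions_by_model.values())))
    labels = []
    for i in range(num_rows):
        scores = defaultdict(float)
        for model_id, preds in predictions_by_model.items():
            scores[preds[i]] += weights.get(model_id, 1.0)
        # Highest total weight wins; ties go to the smaller label for determinism.
        labels.append(min(scores, key=lambda label: (-scores[label], label)))
    return labels


def reweight_for_client(model_owner, client_id, accuracy, threshold, factor):
    """Return new weights if the client's reported accuracy misses the threshold, else None.

    factor > 1 up-weights that client's simple models (as in claim 6);
    factor < 1 down-weights them (as in claim 7). All other models keep weight 1.0.
    """
    if accuracy >= threshold:
        return None  # threshold satisfied: no relabeling needed
    return {m: (factor if owner == client_id else 1.0) for m, owner in model_owner.items()}


# Hypothetical predictions of five simple models on three synthetic rows.
predictions = {
    "a1": [0, 1, 1], "a2": [0, 1, 0],  # simple models trained on client A
    "b1": [1, 0, 0], "b2": [1, 0, 0],  # simple models trained on client B
    "c1": [1, 1, 0],                   # simple model trained on client C
}
owner = {"a1": "A", "a2": "A", "b1": "B", "b2": "B", "c1": "C"}

baseline = weighted_labels(predictions, {})                 # [1, 1, 0]
weights = reweight_for_client(owner, "A", 0.62, 0.80, 3.0)  # client A missed the threshold
updated = weighted_labels(predictions, weights)             # [0, 1, 0]
```

Up-weighting client A flips the first synthetic row's label toward what A's simple models predicted, which is the intended effect of claim 6; passing a `factor` below 1.0 would instead reduce A's influence, as in claim 7.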
Description
SUMMARY
Training machine learning models on client data may expose sensitive data to malicious actors. This risk is especially prevalent in federated learning systems, where models are trained on client data at client devices before being transmitted to a central server. Client data is again exposed when a model trained at the central server on aggregated client data is transmitted back to the client devices. Malicious actors may intercept a trained model while it is in transit and ascertain client information from the trained model. In some circumstances, this may allow sensitive data, such as clients' medical or financial information, to be exploited. Thus, a mechanism is desired for protecting client data from malicious actors when transmitting machine learning models trained on client data.
Methods and systems are described herein for protecting client data from malicious actors while training machine learning models on client data. A data protection system may be built and configured to perform operations discussed herein. The data protection system may transmit multiple simple models from a central server to each client device. Each simple model may be generated so that once it is transmitted to a client device, it trains on a subset of the client data on that client device. For example, the data protection system may transmit several simple models to a particular client device and each simple model may train on a subset of features of the client data on that client device. By dividing the features across multiple simple models, the data protection system ensures that no single simple model learns a complete picture of the data for any client. The central server then receives the trained simple models from the client devices. Even if one or more of these trained simple models are intercepted during transmission to the central server, the client data is protected by each simple model's limited understanding of the client data.
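The feature split and client-side training described above can be sketched as follows. This is a minimal illustration, not the patented implementation: it assumes the server knows which feature pairs are related (as in claim 9) and separates them with a greedy coloring, and it uses a nearest-centroid classifier as a stand-in "simple model". The names `partition_features`, `NearestCentroidModel`, and `train_on_client` are hypothetical.

```python
from collections import defaultdict


def partition_features(features, related_pairs, num_subsets):
    """Assign features to subsets so that related features land in different subsets.

    Greedy graph coloring: each feature goes to the smallest subset that holds
    none of its related features.
    """
    neighbors = defaultdict(set)
    for a, b in related_pairs:
        neighbors[a].add(b)
        neighbors[b].add(a)
    assignment, sizes = {}, [0] * num_subsets
    for feature in features:
        blocked = {assignment[n] for n in neighbors[feature] if n in assignment}
        allowed = [i for i in range(num_subsets) if i not in blocked]
        if not allowed:
            raise ValueError(f"cannot separate {feature!r} from its related features")
        chosen = min(allowed, key=lambda i: (sizes[i], i))
        assignment[feature] = chosen
        sizes[chosen] += 1
    return [[f for f in features if assignment[f] == i] for i in range(num_subsets)]


class NearestCentroidModel:
    """Hypothetical 'simple model': per-label mean of its assigned features only."""

    def __init__(self, feature_subset):
        self.features = list(feature_subset)
        self.centroids = {}

    def fit(self, rows, labels):
        grouped = defaultdict(list)
        for row, label in zip(rows, labels):
            # Only the assigned features are read from each client row.
            grouped[label].append([row[f] for f in self.features])
        self.centroids = {
            label: [sum(col) / len(vecs) for col in zip(*vecs)]
            for label, vecs in grouped.items()
        }
        return self

    def predict(self, row):
        vec = [row[f] for f in self.features]

        def sq_dist(label):
            return sum((a - b) ** 2 for a, b in zip(vec, self.centroids[label]))

        # Sorting first makes ties resolve to the smallest label.
        return min(sorted(self.centroids), key=sq_dist)


def train_on_client(feature_subsets, rows, labels):
    """Runs on the client device; returns only the fitted models, not the raw rows."""
    return [NearestCentroidModel(s).fit(rows, labels) for s in feature_subsets]


features = ["age", "income", "balance", "tenure"]
subsets = partition_features(features, [("income", "balance"), ("age", "tenure")], 2)
# subsets == [["age", "balance"], ["income", "tenure"]]
```

Because each returned model holds only statistics over its own feature subset, an intercepted model reveals less than one trained on every feature; note that class means are still aggregate information about the client data, so this split limits exposure rather than eliminating it.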
At the central server, the data protection system may input an unlabeled synthetic dataset into an ensemble model. The ensemble model may include the simple models trained at the client devices. The data protection system may cause each simple model to generate predictions for the unlabeled synthetic dataset based on that simple model's limited understanding of the client data on which it was trained. For example, based on the subset of features on which a particular simple model was trained, that simple model may predict labels for the synthetic dataset. The ensemble model may then aggregate the predictions generated by the various simple models, and the data protection system may use the aggregated predictions as labels for the synthetic dataset. The data protection system may then train a new model on the synthetic dataset to predict these labels. The data protection system may transmit this new model to the client devices for deployment. The client data is again protected during transmission because the new model has been trained on synthetic data instead of client data. The data protection system thus enables training of machine learning models on client data while protecting the client data from exposure to malicious actors. In particular, the data protection system may transmit, to each client device, a set of untrained simple models. Each untrained simple model may be generated to be trained on a client device using a subset of features of the client data on the client device. For example, the data protection system may transmit several simple models to a particular client device and each simple model may train on a subset of features of the client data on that client device. By dividing the features across multiple simple models, the data protection system ensures that no single simple model learns a complete picture of the data for any client.
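The server-side steps above (ensemble labeling of the synthetic dataset, then training a new model on synthetic data only) can be sketched as below. This is a hedged illustration under several assumptions: the trained simple models are represented by a hypothetical `SubsetModel` (a linear threshold over its feature subset), aggregation is a plain majority vote (the document leaves the aggregation rule open), and the new model is a hypothetical nearest-centroid `CentroidStudent`. The synthetic rows are hand-written placeholders rather than output of any particular synthetic-data generator.

```python
from collections import Counter, defaultdict


class SubsetModel:
    """Hypothetical stand-in for a trained simple model: a linear threshold over its features."""

    def __init__(self, features, weights, threshold):
        self.features, self.weights, self.threshold = features, weights, threshold

    def predict(self, row):
        score = sum(w * row[f] for f, w in zip(self.features, self.weights))
        return 1 if score >= self.threshold else 0


def ensemble_labels(models, synthetic_rows):
    """Label each synthetic row by majority vote; each model sees only its own feature subset."""
    labels = []
    for row in synthetic_rows:
        votes = Counter(m.predict({f: row[f] for f in m.features}) for m in models)
        # Most votes wins; ties go to the smaller label so results are deterministic.
        labels.append(min(votes, key=lambda label: (-votes[label], label)))
    return labels


class CentroidStudent:
    """Hypothetical 'new model': nearest centroid over all features, fit only on synthetic data."""

    def fit(self, rows, labels):
        self.features = sorted(rows[0])
        grouped = defaultdict(list)
        for row, label in zip(rows, labels):
            grouped[label].append([row[f] for f in self.features])
        self.centroids = {
            label: [sum(col) / len(vecs) for col in zip(*vecs)]
            for label, vecs in grouped.items()
        }
        return self

    def predict(self, row):
        vec = [row[f] for f in self.features]

        def sq_dist(label):
            return sum((a - b) ** 2 for a, b in zip(vec, self.centroids[label]))

        return min(sorted(self.centroids), key=sq_dist)


models = [
    SubsetModel(["age", "balance"], [1.0, 1.0], 40),
    SubsetModel(["income", "tenure"], [1.0, 1.0], 60),
    SubsetModel(["age", "tenure"], [1.0, 2.0], 50),
]
synthetic = [
    {"age": 20, "balance": 2, "income": 25, "tenure": 1},
    {"age": 55, "balance": 8, "income": 80, "tenure": 15},
    {"age": 35, "balance": 6, "income": 30, "tenure": 5},
    {"age": 30, "balance": 3, "income": 70, "tenure": 12},
]
labels = ensemble_labels(models, synthetic)  # [0, 1, 0, 1]
student = CentroidStudent().fit(synthetic, labels)
```

The student never reads client rows: its only training inputs are the synthetic rows and the ensemble's labels, which is the property the description relies on when the new model is transmitted back to client devices.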
In some embodiments, the data protection system may receive, from the client devices, the simple models once they have been trained on the client data. Each trained simple model may be trained to generate predictions based on a subset of the features of the client data. For example, each simple model may be trained to predict a label of the client data based on only a subset of the features. The client data is thereby protected by each simple model's limited understanding of the client data. Even if a trained simple model is intercepted during transmission to the central server, a malicious actor cannot ascertain the client data from the model. The data protection system may then retrieve an unlabeled synthetic dataset. In some embodiments, the synthetic dataset may be generated based on the same features as the client data. The data protection system may input, into each trained simple model, a subset of the synthetic dataset. For example, the data protection system may input, into a first simple model trained on the first two feat