
US-12626146-B2 - Data pruning in tree-based fitted Q iteration


Abstract

A computer-implemented method is provided for data reduction in a memory device for machine learning. The method includes storing, in the memory device, data that has been used for training in a tree-based fitted Q iteration session which learns an action value function with an ensemble of decision trees from the data. The method further includes determining, by a processor device, samples to be removed from the data based on a number of samples which belong to leaf nodes of the decision trees. The method also includes removing, from the memory device, the determined samples from the data to reduce an amount of the data. The method additionally includes learning, by the processor device, a new ensemble of decision trees using the data from which the determined samples have been removed together with new data.
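The data-reduction step summarized above can be illustrated with a minimal Python sketch. It is not the patented implementation. It assumes that the leaf membership of every stored sample has already been computed as a hypothetical matrix `leaf_ids`, where `leaf_ids[i][k]` is the index of the leaf that sample `i` reaches in tree `k`. scikit-learn's `apply` method on a fitted tree ensemble returns an array of this shape. The function names, the `STATISTICS` table (mirroring the mean, minimum, maximum, and median variants named in the claims), and the removal-probability schedule `1 - target / n` used for the probabilistic variant are all illustrative assumptions.

```python
import random
import statistics
from collections import Counter
from typing import Callable, Dict, List, Optional, Sequence, TypeVar

T = TypeVar("T")

# Ways to summarize the occupancies of the leaves a sample falls into (one leaf per tree).
STATISTICS: Dict[str, Callable[[Sequence[int]], float]] = {
    "mean": statistics.mean,
    "min": min,
    "max": max,
    "median": statistics.median,
}


def leaf_occupancy(leaf_ids: Sequence[Sequence[int]]) -> List[List[int]]:
    """For each sample, list how many stored samples share its leaf in each tree.

    leaf_ids[i][k] is the (hypothetical, precomputed) leaf index of sample i in tree k.
    """
    n_trees = len(leaf_ids[0]) if leaf_ids else 0
    counts = [Counter(row[k] for row in leaf_ids) for k in range(n_trees)]
    return [[counts[k][leaf] for k, leaf in enumerate(row)] for row in leaf_ids]


def prune_by_threshold(
    samples: Sequence[T],
    leaf_ids: Sequence[Sequence[int]],
    threshold: float,
    stat: str = "mean",
) -> List[T]:
    """Keep only samples whose summarized leaf occupancy is at most threshold."""
    summarize = STATISTICS[stat]
    return [
        s for s, occ in zip(samples, leaf_occupancy(leaf_ids))
        if summarize(occ) <= threshold
    ]


def prune_stochastically(
    samples: Sequence[T],
    leaf_ids: Sequence[Sequence[int]],
    target: float,
    rng: Optional[random.Random] = None,
) -> List[T]:
    """Remove each sample with probability 1 - target / n, where n is the mean
    occupancy of the leaves it belongs to (an assumed schedule that leaves
    roughly `target` samples per leaf in expectation)."""
    rng = rng if rng is not None else random.Random()
    kept = []
    for s, occ in zip(samples, leaf_occupancy(leaf_ids)):
        n = statistics.mean(occ)
        p_remove = 0.0 if n <= target else 1.0 - target / n
        if rng.random() >= p_remove:  # keep with probability 1 - p_remove
            kept.append(s)
    return kept
```

The pruned list, concatenated with newly collected samples (for example `prune_by_threshold(old, leaf_ids, 2) + new`), would then serve as the training data for the next ensemble. Leaves that are crowded hold many near-duplicate samples, so thinning them is meant to reduce memory use while losing little information about the action value function.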

Inventors

  • Takayuki Osogami
  • Ryo Iwaki
  • Kohei Miyaguchi

Assignees

  • INTERNATIONAL BUSINESS MACHINES CORPORATION

Dates

Publication Date
2026-05-12
Application Date
2021-03-04

Claims (20)

  1. A computer-implemented method for data reduction in a memory device for machine learning, comprising: storing, in the memory device, data that has been used for training in a tree-based fitted Q iteration session which learns an action value function with an ensemble of decision trees from the data; determining, by a processor device, samples to be removed from the data based on a number of samples which belong to leaf nodes of the decision trees having identical positions within different decision trees from the ensemble of decision trees based on sample similarity to obtain determined samples; removing, from the memory device, the determined samples from the data to reduce an amount of the data based on determined leaf node statistics for sample similarity by using the ensemble of decision trees as a regressor to obtain pruned data; and learning, by the processor device, a new ensemble of decision trees using the pruned data together with new data.
  2. The computer-implemented method of claim 1, wherein the samples which belong to a leaf node are removed from the memory device, when a number of the samples which belong to the leaf node exceeds a pruning threshold based on sample similarity.
  3. The computer-implemented method of claim 1, wherein said storing, determining, removing, and learning steps are repeated recursively to efficiently learn a true action-value function representing the pruned data and the new data.
  4. The computer-implemented method of claim 1, wherein each of the samples is removed with a probability that is determined responsive to an average number of the samples, where the average number of samples is taken over all leaf nodes which a given sample belongs to.
  5. The computer-implemented method of claim 1, wherein the number of samples used to determine whether to remove data comprises an average number of samples that belong to a given leaf node.
  6. The computer-implemented method of claim 1, wherein the number of samples used to determine whether to remove data comprises a minimum number of samples that belong to a given leaf node.
  7. The computer-implemented method of claim 1, wherein the number of samples used to determine whether to remove data comprises a maximum number of samples that belong to a given leaf node.
  8. The computer-implemented method of claim 1, wherein the number of samples used to determine whether to remove data comprises a median of samples that belong to a given leaf node.
  9. The computer-implemented method of claim 1, wherein the machine learning comprises a gradient tree boosting process.
  10. The computer-implemented method of claim 1, further comprising generating a prediction which controls a motion of a motor vehicle using remaining data in the memory device.
  11. The computer-implemented method of claim 1, wherein the fitted Q iteration session computes, from a set of four-tuples, an approximation of an optimal stationary policy, and wherein the set of four-tuples comprises a state at time t, an action at time t, a reward at time t, and a discount factor γ.
  12. A computer program product for data reduction in a memory device for machine learning, the computer program product comprising a non-transitory computer readable storage medium having program instructions embodied therewith, the program instructions executable by a computer to cause the computer to perform a method comprising: storing, in the memory device, data that has been used for training in a tree-based fitted Q iteration session which learns an action value function with an ensemble of decision trees from the data; determining, by a processor device, samples to be removed from the data based on a number of samples which belong to leaf nodes of the decision trees having identical positions within different decision trees from the ensemble of decision trees based on sample similarity to obtain determined samples; removing, from the memory device, the determined samples from the data to reduce an amount of the data based on determined leaf node statistics for sample similarity by using the ensemble of decision trees as a regressor to obtain pruned data; and learning, by the processor device, a new ensemble of decision trees using the pruned data together with new data.
  13. The computer program product of claim 12, wherein the samples which belong to a leaf node are removed from the memory device, when a number of the samples which belong to the leaf node exceeds a pruning threshold based on sample similarity.
  14. The computer program product of claim 12, wherein said storing, determining, removing, and learning steps are repeated recursively to efficiently learn a true action-value function representing the pruned data and the new data.
  15. The computer program product of claim 12, wherein each of the samples is removed with a probability that is determined responsive to an average number of the samples, where the average number of samples is taken over all leaf nodes which a given sample belongs to.
  16. The computer program product of claim 12, wherein the number of samples used to determine whether to remove data comprises an average number of samples that belong to a given leaf node.
  17. The computer program product of claim 12, wherein the number of samples used to determine whether to remove data comprises a minimum number of samples that belong to a given leaf node.
  18. The computer program product of claim 12, wherein the number of samples used to determine whether to remove data comprises a maximum number of samples that belong to a given leaf node.
  19. The computer program product of claim 12, wherein the number of samples used to determine whether to remove data comprises a median of samples that belong to a given leaf node.
  20. A computer processing system for data reduction in a memory device for machine learning, comprising: a memory device configured to store program code; a processor device operatively coupled to the memory device for running the program code to store, in the memory device, data that has been used for training in a tree-based fitted Q iteration session which learns an action value function with an ensemble of decision trees from the data; determine samples to be removed from the data based on a number of samples which belong to leaf nodes of the decision trees having identical positions within different decision trees from the ensemble of decision trees based on sample similarity to obtain determined samples; remove, from the memory device, the determined samples from the data to reduce an amount of the data based on leaf node statistics for sample similarity by using the ensemble of decision trees as a regressor to obtain pruned data; and learn a new ensemble of decision trees using the pruned data together with new data.
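Claim 11 ties the fitted Q iteration session to a set of four-tuples. In the standard fitted Q iteration formulation, with four-tuples (x_t, u_t, r_t, x_{t+1}) as defined in the Description's background, iteration N turns every stored transition into a regression example with input (x_t, u_t) and target r_t + γ max_u Q_{N-1}(x_{t+1}, u), starting from Q_0 = 0, and fits a tree ensemble to those examples to obtain Q_N. The Python sketch below builds only these targets. The finite action set `actions`, the callable `q_prev` standing in for the previously fitted ensemble, and the omission of terminal-state handling are assumptions made for illustration.

```python
from typing import Callable, Hashable, List, Sequence, Tuple

State = Tuple[float, ...]
Action = Hashable
# One stored transition: (x_t, u_t, r_t, x_{t+1}).
Transition = Tuple[State, Action, float, State]


def fqi_targets(
    transitions: Sequence[Transition],
    q_prev: Callable[[State, Action], float],
    actions: Sequence[Action],
    gamma: float,
) -> List[Tuple[Tuple[State, Action], float]]:
    """Build one fitted-Q-iteration regression set.

    Each example pairs the input (x_t, u_t) with the target
    r_t + gamma * max_u q_prev(x_{t+1}, u). Terminal states are not
    treated specially in this sketch.
    """
    training_set = []
    for x, u, r, x_next in transitions:
        target = r + gamma * max(q_prev(x_next, a) for a in actions)
        training_set.append(((x, u), target))
    return training_set
```

With `q_prev = lambda x, u: 0.0` (that is, Q_0), the targets reduce to the immediate rewards. Because the pruning step only deletes stored four-tuples, the same routine would apply unchanged to the pruned data merged with newly collected data before each new ensemble is fitted.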

Description

BACKGROUND

The present invention generally relates to machine learning, and more particularly to data pruning in tree-based fitted Q iteration.

In applications of Reinforcement Learning (RL) to industrial problems (robotics, healthcare, and so forth), one often uses data that has been collected in advance (batch RL). The additional data collected by RL agents can also be used to retrain the agents, but only infrequently (semi-batch RL): retraining happens only after a large amount of additional data has been collected, because of the high cost of testing before deployment.

Reinforcement learning aims to determine an optimal control policy from interaction with a system or from observations gathered from a system. In batch mode, this can be achieved by approximating the so-called Q-function based on a set of four-tuples (x_t, u_t, r_t, x_{t+1}), where x_t denotes the system state at time t, u_t the control action taken, r_t the instantaneous reward obtained, and x_{t+1} the successor state of the system, and then determining the control policy from this Q-function. The Q-function approximation may be obtained as the limit of a sequence of (batch mode) supervised learning problems.

In such batch or semi-batch RL for industrial applications, Tree-based Fitted Q Iteration is known to outperform other techniques. The policies found with Tree-based Fitted Q Iteration are known to achieve the highest performance when ensemble methods are used (extremely or totally randomized trees, gradient tree boosting), which, however, incur increased computational cost, particularly at training time. Even if the time for a single training run is acceptable, Tree-based Fitted Q Iteration has several key hyperparameters and benefits significantly from a grid search, which requires many training runs.

SUMMARY

According to aspects of the present invention, a computer-implemented method is provided for data reduction in a memory device for machine learning.
The method includes storing, in the memory device, data that has been used for training in a tree-based fitted Q iteration session which learns an action value function with an ensemble of decision trees from the data. The method further includes determining, by a processor device, samples to be removed from the data based on a number of samples which belong to leaf nodes of the decision trees. The method also includes removing, from the memory device, the determined samples from the data to reduce an amount of the data. The method additionally includes learning, by the processor device, a new ensemble of decision trees using the data from which the determined samples have been removed together with new data.

According to other aspects of the present invention, a computer program product is provided for data reduction in a memory device for machine learning. The computer program product includes a non-transitory computer readable storage medium having program instructions embodied therewith. The program instructions are executable by a computer to cause the computer to perform a method. The method includes storing, in the memory device, data that has been used for training in a tree-based fitted Q iteration session which learns an action value function with an ensemble of decision trees from the data. The method further includes determining, by a processor device, samples to be removed from the data based on a number of samples which belong to leaf nodes of the decision trees. The method also includes removing, from the memory device, the determined samples from the data to reduce an amount of the data. The method additionally includes learning, by the processor device, a new ensemble of decision trees using the data from which the determined samples have been removed together with new data.

According to yet other aspects of the present invention, a computer processing system is provided for data reduction in a memory device for machine learning.
The system includes a memory device configured to store program code. The system further includes a processor device operatively coupled to the memory device for running the program code to store, in the memory device, data that has been used for training in a tree-based fitted Q iteration session which learns an action value function with an ensemble of decision trees from the data. The processor further runs the program code to determine samples to be removed from the data based on a number of samples which belong to leaf nodes of the decision trees. The processor also runs the program code to remove, from the memory device, the determined samples from the data to reduce an amount of the data. The processor additionally runs the program code to learn a new ensemble of decision trees using the data from which the determined samples have been removed together with new data. These and other features and advantages will become apparent from the following detailed description of illustrative embodiments thereof, which is to be read in connection with the accompanying drawings. BR