CN-119106725-B - Trajectory prediction model training method and device based on meta learning
Abstract
The application provides a trajectory prediction model training method and device based on meta learning. The method comprises: obtaining trajectory data of at least one target scene; constructing meta-tasks from the trajectory data and dividing the meta-tasks into a plurality of domain offset groups; in any current training round of the training process, calling the trajectory prediction model to sequentially execute the meta-tasks in each domain offset group according to the training data of the meta-tasks, to obtain a target loss function for executing the meta-tasks; performing gradient propagation within each domain offset group to obtain updated domain offset parameters; and updating the original model parameters of the trajectory prediction model in the current training round according to the domain offset parameters of each domain offset group, to obtain final model parameters. The method and device address the technical problem in the prior art that a trajectory prediction model, in actual application, still depends excessively on the specific source domain data used in training and generalizes poorly across scenes, so that prediction performance degrades on new, previously unseen target domain data.
Inventors
- FAN ZIDE
- HUANG FEILONG
- JIANG LIZHENG
- LI XIAOHE
- MOU FANGLI
- GENG YING
- DENG YAWEN
- ZHU KEQING
- JIANG CHUANAO
Assignees
- Aerospace Information Research Institute, Chinese Academy of Sciences (中国科学院空天信息创新研究院)
Dates
- Publication Date
- 20260508
- Application Date
- 20240813
Claims (8)
- 1. A trajectory prediction model training method based on meta learning, characterized in that the training process of the trajectory prediction model is iterated by training rounds, and the method comprises: obtaining trajectory data of at least one target scene, wherein the trajectory data comprises virtual source domain data or virtual target domain data, the trajectory prediction model is used for predicting possible future trajectory changes in the target scene according to the trajectory changes of some targets in the current target scene, and the trajectory data is provided with five trajectory labels, the five trajectory labels comprising a position label, a time label, a weather label, a traffic label and a trajectory type label used for describing the target scene; constructing meta-tasks according to the trajectory data, and dividing the meta-tasks into a plurality of domain offset groups, wherein each domain offset group comprises at least two meta-tasks, and the training data of the meta-tasks comprise virtual source domain data for meta-training and virtual target domain data for meta-testing; in any current training round of the training process, for each domain offset group, calling the trajectory prediction model, based on the training data of the meta-tasks, to sequentially execute the meta-tasks in the domain offset group to obtain a target loss function for executing the meta-tasks, wherein the target loss function is constructed according to a source domain loss value, a target domain loss value and a cross-scene attention alignment loss value; performing gradient propagation within the domain offset group according to the target loss function to obtain updated domain offset parameters of the trajectory prediction model, and updating the original model parameters of the trajectory prediction model in the current training round according to the domain offset parameters of each domain offset group to obtain final model parameters; wherein calling the trajectory prediction model, based on the training data of the meta-tasks, to sequentially execute the meta-tasks in the domain offset group to obtain a target loss function for executing the meta-tasks comprises: for each meta-task, using the virtual source domain data to perform meta-training on the trajectory prediction model, and updating the meta-learning original parameters of the trajectory prediction model according to a source domain loss value preset in the meta-training process to obtain meta-learning temporary parameters; performing a meta-test on the trajectory prediction model with the meta-learning temporary parameters using the virtual target domain data to obtain a target domain loss value in the meta-test process; determining a cross-scene attention alignment loss value of the trajectory prediction model in the meta-training process and the meta-test process; and constructing the target loss function for executing the meta-task according to the source domain loss value, the target domain loss value and the cross-scene attention alignment loss value; wherein determining the cross-scene attention alignment loss value of the trajectory prediction model in the meta-training process and the meta-test process comprises: for each first data sample in the virtual source domain data in the meta-training process, determining the training query vector and the training key vector obtained when the trajectory prediction model is trained on the first data sample, and determining the mean of the training query vectors and the mean of the training key vectors; for each second data sample in the virtual target domain data in the meta-test process, determining the test query vector and the test key vector obtained when the trajectory prediction model is tested on the second data sample, and determining the mean of the test query vectors and the mean of the test key vectors; determining a first difference between the mean of the training query vectors and the mean of the test query vectors and a first norm of the first difference, and determining a second difference between the mean of the training key vectors and the mean of the test key vectors and a second norm of the second difference; and determining the sum of the first norm and the second norm, and constructing the cross-scene attention alignment loss function according to the sum.
- 2. The meta-learning-based trajectory prediction model training method of claim 1, wherein the number of trajectory labels of the virtual source domain data is the same as the number of trajectory labels of the virtual target domain data, and the trajectory label value of the virtual source domain data is different from the trajectory label value of the virtual target domain data.
- 3. The method for training a trajectory prediction model based on meta learning according to claim 1, wherein performing gradient propagation in a domain offset group according to the target loss function to obtain updated domain offset parameters of the trajectory prediction model comprises: determining a task execution order in which the trajectory prediction model executes the meta-tasks in the domain offset group; sequentially updating the meta-learning original parameters of the trajectory prediction model within the domain offset group through the target loss function according to the task execution order; and after the trajectory prediction model executes the last meta-task in the domain offset group, taking the parameters obtained by updating the meta-learning original parameters for the last meta-task as the domain offset parameters of the domain offset group.
- 4. The method for training a trajectory prediction model based on meta learning according to claim 1, wherein updating the original model parameters of the trajectory prediction model in the current training round according to the domain offset parameters of each domain offset group to obtain final model parameters comprises: determining a parameter sum of the domain offset parameters in each domain offset group, and determining an offset difference between the parameter sum and the original model parameters; weighting the mean of the offset differences by a parallel learning rate, and adding the weighted result to the original model parameters to obtain the model update parameters for the next training round; and taking the model update parameters obtained in the last training round of the training process as the final model parameters.
- 5. A training device for a trajectory prediction model based on meta learning, the device comprising: an acquisition module, used for acquiring trajectory data of at least one target scene, wherein the trajectory data comprises virtual source domain data or virtual target domain data, the trajectory prediction model is used for predicting possible future trajectory changes in the target scene according to the trajectory changes of some targets in the current target scene, and the trajectory data is provided with five trajectory labels, the five trajectory labels comprising a position label, a time label, a weather label, a traffic label and a trajectory type label used for describing the target scene; a division module, used for constructing meta-tasks according to the trajectory data and dividing the meta-tasks into a plurality of domain offset groups, wherein each domain offset group comprises at least two meta-tasks, and the training data of the meta-tasks comprise virtual source domain data for meta-training and virtual target domain data for meta-testing; an execution module, used for, in any current training round of the training process and for each domain offset group, calling the trajectory prediction model, based on the training data of the meta-tasks, to sequentially execute the meta-tasks in the domain offset group to obtain a target loss function for executing the meta-tasks, wherein the target loss function is constructed according to a source domain loss value, a target domain loss value and a cross-scene attention alignment loss value; and an updating module, used for performing gradient propagation within the domain offset group according to the target loss function to obtain updated domain offset parameters of the trajectory prediction model, and updating the original model parameters of the trajectory prediction model in the current training round according to the domain offset parameters of each domain offset group to obtain final model parameters; wherein calling the trajectory prediction model, based on the training data of the meta-tasks, to sequentially execute the meta-tasks in the domain offset group to obtain a target loss function for executing the meta-tasks comprises: for each meta-task, using the virtual source domain data to perform meta-training on the trajectory prediction model, and updating the meta-learning original parameters of the trajectory prediction model according to a source domain loss value preset in the meta-training process to obtain meta-learning temporary parameters; performing a meta-test on the trajectory prediction model with the meta-learning temporary parameters using the virtual target domain data to obtain a target domain loss value in the meta-test process; determining a cross-scene attention alignment loss value of the trajectory prediction model in the meta-training process and the meta-test process; and constructing the target loss function for executing the meta-task according to the source domain loss value, the target domain loss value and the cross-scene attention alignment loss value; wherein determining the cross-scene attention alignment loss value of the trajectory prediction model in the meta-training process and the meta-test process comprises: for each first data sample in the virtual source domain data in the meta-training process, determining the training query vector and the training key vector obtained when the trajectory prediction model is trained on the first data sample, and determining the mean of the training query vectors and the mean of the training key vectors; for each second data sample in the virtual target domain data in the meta-test process, determining the test query vector and the test key vector obtained when the trajectory prediction model is tested on the second data sample, and determining the mean of the test query vectors and the mean of the test key vectors; determining a first difference between the mean of the training query vectors and the mean of the test query vectors and a first norm of the first difference, and determining a second difference between the mean of the training key vectors and the mean of the test key vectors and a second norm of the second difference; and determining the sum of the first norm and the second norm, and constructing the cross-scene attention alignment loss function according to the sum.
- 6. An electronic device comprising a memory, a processor and a computer program stored on the memory and executable on the processor, wherein the processor implements a method of training a meta-learning based trajectory prediction model as claimed in any one of claims 1 to 4 when the computer program is executed by the processor.
- 7. A non-transitory computer readable storage medium having stored thereon a computer program, wherein the computer program when executed by a processor implements the meta-learning based trajectory prediction model training method according to any one of claims 1 to 4.
- 8. A computer program product comprising a computer program, characterized in that the computer program, when executed by a processor, implements a method of training a trajectory prediction model based on meta learning as claimed in any one of claims 1 to 4.
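The cross-scene attention alignment loss recited in claims 1 and 5 can be illustrated with a minimal Python sketch. It is not the patented implementation: the claims do not state which norm is used or how the three loss values are combined, so the L2 norm, the plain weighted sum, and the names `mean_vector`, `l2_norm`, `attention_alignment_loss`, `combine_target_loss` and `align_weight` are all assumptions made for illustration.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def mean_vector(vectors: List[Vector]) -> List[float]:
    """Element-wise mean of a non-empty list of equal-length vectors."""
    count = len(vectors)
    return [sum(column) / count for column in zip(*vectors)]


def l2_norm(vector: Vector) -> float:
    """Euclidean norm (assumed here; the claims only say 'norm')."""
    return math.sqrt(sum(x * x for x in vector))


def attention_alignment_loss(
    train_queries: List[Vector],
    train_keys: List[Vector],
    test_queries: List[Vector],
    test_keys: List[Vector],
) -> float:
    """Norm of the query-mean difference plus norm of the key-mean difference."""
    query_diff = [a - b for a, b in zip(mean_vector(train_queries), mean_vector(test_queries))]
    key_diff = [a - b for a, b in zip(mean_vector(train_keys), mean_vector(test_keys))]
    return l2_norm(query_diff) + l2_norm(key_diff)


def combine_target_loss(
    source_loss: float,
    target_loss: float,
    align_loss: float,
    align_weight: float = 1.0,
) -> float:
    """Hypothetical combination: a weighted sum of the three loss values."""
    return source_loss + target_loss + align_weight * align_loss
```

For example, with training queries `[[1, 0], [3, 0]]`, test queries `[[0, 0], [0, 0]]`, training keys `[[0, 3]]` and test keys `[[0, -1]]`, the query means differ by `[2, 0]` and the key means by `[0, 4]`, so the alignment loss under these assumptions is 2 + 4 = 6.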
Description
Trajectory prediction model training method and device based on meta learning
Technical Field
The invention relates to the technical field of autonomous driving, in particular to a trajectory prediction model training method and device based on meta learning.
Background
Trajectory prediction techniques generally play a critical midstream role in autonomous driving systems, and their predictive performance depends directly on the accuracy of upstream target detection and tracking. In practical applications, trajectory prediction for vehicles, pedestrians and other targets is generally limited by the performance of sensor devices, changing environmental conditions and differences between prediction algorithms, and deviations in features and data distribution often arise in the upstream process. These deviations not only degrade the quality of the data, but can also cause serious domain shift (domain offset) problems in downstream tasks.
To address the domain shift problem in the trajectory prediction task, the prior art generally adopts domain generalization methods, which train on data from a plurality of source domains and aim to learn general feature representations applicable to all source domains, so that the trained prediction model can predict on unknown target domain data. For example, meta-learning-based methods attempt domain generalization by simulating the domain shift that occurs in real-world situations. However, because the meta-tasks are generally not well defined, such models may still depend excessively on the specific source domain data used in training and generalize poorly across scenes, so that prediction performance degrades on new, previously unseen target domain data.
Disclosure of Invention
The invention provides a trajectory prediction model training method and device based on meta learning, which are used for solving the technical problem in the prior art that a meta-learning-based trajectory prediction model may still depend excessively on the specific source domain data used in training and generalize poorly across scenes, so that prediction performance degrades on new, previously unseen target domain data.
The invention provides a trajectory prediction model training method based on meta learning, which comprises the following steps: acquiring trajectory data of at least one target scene, wherein the trajectory data comprises virtual source domain data or virtual target domain data; constructing meta-tasks according to the trajectory data, and dividing the meta-tasks into a plurality of domain offset groups, wherein each domain offset group comprises at least two meta-tasks, and the training data of the meta-tasks comprise virtual source domain data for meta-training and virtual target domain data for meta-testing; in any current training round of the training process, for each domain offset group, calling the trajectory prediction model, based on the training data of the meta-tasks, to sequentially execute the meta-tasks in the domain offset group to obtain a target loss function for executing the meta-tasks, wherein the target loss function is constructed according to a source domain loss value, a target domain loss value and a cross-scene attention alignment loss value; and performing gradient propagation within the domain offset group according to the target loss function to obtain updated domain offset parameters of the trajectory prediction model, and updating the original model parameters of the trajectory prediction model in the current training round according to the domain offset parameters of each domain offset group to obtain final model parameters.
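The structure of one training round described above (sequential updates inside each domain offset group, followed by an update of the original model parameters from all groups) can be sketched as follows. This is an illustrative reading under stated assumptions, not the patented implementation: each meta-task is reduced to a function returning the gradient of its target loss, the inner update is plain gradient descent with a hypothetical `inner_lr`, and the outer update is read as moving the original parameters toward the mean offset of the group results, scaled by the parallel learning rate of claim 4. The names `run_group`, `training_round` and `quadratic_grad` are invented for this example.

```python
from typing import Callable, List

Params = List[float]
# Assumption: a meta-task is represented only by the gradient of its target loss.
GradFn = Callable[[Params], Params]


def run_group(theta: Params, task_grads: List[GradFn], inner_lr: float) -> Params:
    """Execute the meta-tasks of one domain offset group in order.

    Each task updates the parameters left by the previous task; the
    parameters after the last task are the group's domain offset parameters.
    """
    params = list(theta)
    for grad_fn in task_grads:
        grad = grad_fn(params)
        params = [p - inner_lr * g for p, g in zip(params, grad)]
    return params


def training_round(
    theta: Params,
    groups: List[List[GradFn]],
    inner_lr: float,
    parallel_lr: float,
) -> Params:
    """One round: every group starts from theta, then theta is updated once."""
    group_params = [run_group(theta, tasks, inner_lr) for tasks in groups]
    group_count = len(group_params)
    # Mean offset between each group's result and the original parameters.
    mean_offset = [
        sum(result[i] - theta[i] for result in group_params) / group_count
        for i in range(len(theta))
    ]
    return [t + parallel_lr * d for t, d in zip(theta, mean_offset)]


def quadratic_grad(center: float) -> GradFn:
    """Toy stand-in for a meta-task: gradient of sum((p - center) ** 2)."""
    return lambda params: [2.0 * (p - center) for p in params]
```

As a toy check, starting from `theta = [0.0]` with `inner_lr = 0.5`, a group with centers 1 and 2 ends at `[2.0]` and a group with centers 4 and 6 ends at `[6.0]`; with `parallel_lr = 0.5` the round returns `[2.0]`. Because every group starts from the same original parameters, the groups are independent of one another and could be processed in parallel.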
In some embodiments, the trajectory data is provided with five trajectory labels, the five trajectory labels comprising a position label, a time label, a weather label, a traffic label and a trajectory type label used for describing the target scene; the number of trajectory labels of the virtual source domain data is the same as the number of trajectory labels of the virtual target domain data, and the trajectory label values of the virtual source domain data differ from the trajectory label values of the virtual target domain data.
In some embodiments, calling the trajectory prediction model, based on the training data of the meta-tasks, to sequentially execute the meta-tasks in the domain offset group to obtain a target loss function for executing the meta-tasks comprises: for each meta-task, using the virtual source domain data to perform meta-training on the trajectory prediction model, and updating the meta-learning original parameters of the trajectory prediction model according to a source domain loss value preset in the meta-training process to obtain meta-learning temporary parameters; performing a meta-test on the trajectory prediction model with the meta-learning temporary parameters using the virtual target domain data to obtain a target domain loss value in the meta-test process; determining a cross-scene attention alignment loss value of the tr