CN-122021762-A - RWKV model training method, RWKV model prediction method and related devices
Abstract
The application provides a training method and a prediction method for a RWKV model, and related devices. The training method comprises: sequentially splitting a long-sequence sample into a plurality of sample groups; and training the RWKV model according to the plurality of sample groups, wherein, following the order of the sample groups, the final recursive state vector of a preceding sample group is used as the initial recursive state vector of the next sample group, and part of the training process of at least some of the sample groups is executed in parallel. The training efficiency of the RWKV model can thereby be improved to a certain extent.
Inventors
- JIANG ZHI
Assignees
- 阿里健康科技(中国)有限公司
Dates
- Publication Date: 2026-05-12
- Application Date: 2025-12-26
Claims (10)
- 1. A method of training a RWKV model, comprising: sequentially splitting a long-sequence sample into a plurality of sample groups, wherein each sample group comprises a plurality of samples; and training the RWKV model according to the plurality of sample groups, wherein, following the order of the plurality of sample groups, the final recursive state vector of a preceding sample group is used as the initial recursive state vector of the next sample group, and part of the training process of at least some of the sample groups is executed in parallel.
- 2. The method of claim 1, wherein training the RWKV model according to the plurality of sample groups comprises: inputting each sample group into a respective GPU.
- 3. The method of claim 2, wherein training the RWKV model according to the plurality of sample groups further comprises: performing, by the plurality of GPUs operating in parallel, the forward projection operation tasks of the corresponding sample groups.
- 4. The method of claim 3, wherein performing, by the plurality of GPUs operating in parallel, the forward projection operation tasks of the corresponding sample groups comprises: executing, by each GPU in parallel, a first projection operation task, a second projection operation task and a third projection operation task of the corresponding sample group, wherein the first projection operation task is used for generating gating vectors of the samples in the sample group, the second projection operation task is used for generating key vectors of the samples in the sample group, and the third projection operation task is used for generating value vectors of the samples in the sample group.
- 5. The method of claim 4, wherein the plurality of sample groups comprises a first sample group and a second sample group that are adjacent in order, the first sample group corresponding to a first GPU and the second sample group corresponding to a second GPU, the method further comprising: once the first GPU has performed the recursive operation on the first sample group and generated a first recursive state vector and a time-mixing output vector, using the first recursive state vector as the initial recursive state vector with which the second GPU performs the recursive operation on the second sample group; and while the first GPU performs the feedforward network operation on the time-mixing output vector, performing, by the second GPU, the recursive operation on the second sample group to generate a second recursive state vector.
- 6. The method of claim 1, further comprising: acquiring historical advertisement display information, wherein the historical advertisement display information comprises display offer information, advertisement information and click results; and generating, in the time order of the historical advertisement display information, advertisement vector representations from the advertisement information and the corresponding display offer information, wherein each advertisement vector representation and its corresponding click result form a sample, and the samples are arranged in time order to form the long-sequence sample.
- 7. A click rate prediction method, comprising: predicting, by a RWKV model obtained by the training method of any one of claims 1 to 5, a click rate for input display offer information and advertisement information.
- 8. A computer readable storage medium, having stored thereon a computer program which, when executed by a processor, causes the processor to implement the method of any of claims 1 to 7.
- 9. A computer device comprising a memory and a processor, wherein the memory stores at least one computer program, the at least one computer program being loaded and executed by the processor to implement the method of any one of claims 1 to 7.
- 10. A computer program product comprising computer instructions which, when executed by a processor, implement the method of any of claims 1 to 7.
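As an informal illustration of the overlap recited in claims 3 to 5 (not part of the claims), the sketch below models the forward pass as a simple timeline. It assumes one GPU per sample group, that all projections start at time zero, that each recursive operation waits for the previous group's final state, and that each feedforward operation runs on its own GPU right after that group's recursive operation. The function names and the durations `proj`, `rec` and `ffn` are hypothetical values chosen only for the example.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class GroupTimeline:
    gpu: int                 # index of the GPU (and sample group)
    recurrence_start: float  # when the recursive operation begins
    recurrence_end: float    # when the final recursive state is available
    ffn_end: float           # when the feedforward operation finishes


def pipelined_schedule(num_groups: int, proj: float, rec: float, ffn: float) -> List[GroupTimeline]:
    """Timeline when projections run in parallel and FFN overlaps the next recurrence."""
    timelines = []
    prev_rec_end = 0.0
    for g in range(num_groups):
        # The recurrence needs this group's projections and the previous group's state.
        rec_start = max(proj, prev_rec_end)
        rec_end = rec_start + rec
        timelines.append(GroupTimeline(g, rec_start, rec_end, rec_end + ffn))
        prev_rec_end = rec_end
    return timelines


def serial_time(num_groups: int, proj: float, rec: float, ffn: float) -> float:
    """Baseline with no overlap: every group runs all three stages back to back."""
    return num_groups * (proj + rec + ffn)


# Hypothetical stage durations in arbitrary time units.
schedule = pipelined_schedule(num_groups=4, proj=1.0, rec=2.0, ffn=3.0)
pipelined_total = max(t.ffn_end for t in schedule)
serial_total = serial_time(4, 1.0, 2.0, 3.0)
print(pipelined_total, serial_total)
```

In this toy model only the recursive operations form a chain across GPUs; the projections and feedforward operations fall off the critical path. The backward pass and communication cost are not modelled.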
Description
RWKV model training method, RWKV model prediction method and related devices
Technical Field
One or more embodiments of the present application relate to the field of artificial intelligence, and in particular to a training method, a prediction method and related devices for a RWKV model.
Background
In the related art of RWKV (Receptance Weighted Key Value) model training, enhancing the model's ability to handle long sequences (e.g., user behavior sequences, medical consultation dialogue sequences, or advertisement exposure sequences) usually requires inputting long-sequence samples during the training phase so that the RWKV model can learn dependencies across time steps. However, the sequence-dependent computation of the RWKV model is typically updated step by step across time steps based on the recursive state vector, which makes training the RWKV model inefficient.
Disclosure of Invention
In view of this, one or more embodiments of the present application provide a training method, a prediction method and related devices for a RWKV model, which can improve the training efficiency of the RWKV model to a certain extent.
In a first aspect, one or more embodiments of the present application provide a training method for a RWKV model, which includes: sequentially splitting a long-sequence sample into a plurality of sample groups, where each sample group includes a plurality of samples; and training the RWKV model according to the plurality of sample groups, where, following the order of the plurality of sample groups, the final recursive state vector of a preceding sample group is used as the initial recursive state vector of the next sample group, and part of the training process of at least some of the sample groups is performed in parallel.
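The state hand-off in the first aspect can be illustrated with a toy recurrence. The sketch below is a minimal illustration, not the RWKV time-mixing kernel: it assumes a scalar state updated as `state = decay * state + x`, and the names `split_into_groups`, `run_recurrence`, `run_grouped` and the `decay` value are hypothetical. It shows that processing the groups in order, with each group starting from the previous group's final state, reproduces the result of processing the unsplit sequence.

```python
from typing import List, Tuple


def split_into_groups(sequence: List[float], group_size: int) -> List[List[float]]:
    """Split a long sequence into consecutive groups, preserving order."""
    return [sequence[i:i + group_size] for i in range(0, len(sequence), group_size)]


def run_recurrence(samples: List[float], init_state: float, decay: float = 0.9) -> Tuple[List[float], float]:
    """Toy stand-in for a recurrent update: state = decay * state + x (assumption)."""
    state = init_state
    outputs = []
    for x in samples:
        state = decay * state + x
        outputs.append(state)
    return outputs, state


def run_grouped(sequence: List[float], group_size: int, decay: float = 0.9) -> Tuple[List[float], float]:
    """Process groups in order; each group starts from the previous group's final state."""
    state = 0.0
    outputs: List[float] = []
    for group in split_into_groups(sequence, group_size):
        group_out, state = run_recurrence(group, state, decay)
        outputs.extend(group_out)
    return outputs, state


# Hypothetical long-sequence sample of 20 scalar values.
long_sequence = [float(i % 7) for i in range(20)]
full_out, full_state = run_recurrence(long_sequence, 0.0)
grouped_out, grouped_state = run_grouped(long_sequence, group_size=6)
print(grouped_out == full_out, grouped_state == full_state)
```

Because the grouped run performs the same updates in the same order, splitting does not change the recurrent state. Any gain in the described method comes from running the other parts of training in parallel across groups.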
In a second aspect, one or more embodiments of the present application provide a click rate prediction method, including predicting a click rate for input display offer information and advertisement information using a RWKV model obtained by the training method described above.
In a third aspect, one or more embodiments of the present application provide a training apparatus for a RWKV model, which includes: a splitting module configured to sequentially split a long-sequence sample into a plurality of sample groups, where each sample group includes a plurality of samples; and a training module configured to train the RWKV model according to the plurality of sample groups, where, following the order of the plurality of sample groups, the final recursive state vector of a preceding sample group is used as the initial recursive state vector of the adjacent next sample group, and part of the training process of at least some of the sample groups is performed in parallel.
In a fourth aspect, one or more embodiments of the present application provide a click rate prediction apparatus configured to predict a click rate for input display offer information and advertisement information using a RWKV model obtained by the training method described above.
In a fifth aspect, one or more embodiments of the present application provide a computer device including a memory and a processor, the memory storing at least one computer program, the at least one computer program being loaded and executed by the processor to implement a method as described above.
In a sixth aspect, one or more embodiments of the present application provide a computer program product comprising computer instructions which, when executed by a processor, implement a method as described above.
In a seventh aspect, one or more embodiments of the present application provide a computer-readable storage medium having stored thereon a computer program which, when executed by a processor, causes the processor to implement a method as described above.
As can be seen from the foregoing, in the embodiments of the present application a long-sequence sample is sequentially split into a plurality of sample groups, and the RWKV model is trained group by group while the order of the sample groups is preserved: the final recursive state vector of a preceding sample group is used as the initial recursive state vector of the next sample group, and part of the training process of at least some of the sample groups is performed in parallel. The long-sequence training process is thus split and parallelized without breaking the continuity of the RWKV model's recursive state. The length of a single training sequence is reduced and training parallelism is increased, thereby improving the training efficiency of the RWKV model to a certain extent.
Drawings
FIG. 1 is a schematic diagram of the training process modules of a RWKV model according to one embodiment of the present application.
FIG. 2 is a schematic block diagram of training tasks in a RWKV model training process according to an embodiment of