在小gpu上训练bart

Posted by Weihao Zeng on August 7, 2021

2021-08-07-在小GPU上训练BART

https://github.com/pytorch/fairseq/issues/1413

在利用fairseq框架复现BART模型在CNN/DM训练过程中,对于显存较小的GPU, 如果采用降低MAX_TOKENS512的方法,可能会遇到的问题是

AssertionError: sentence at index 227550 of size 728 exceeds max_tokens limit of 512!

如果将MAX_TOKENS设置到1024, 则会遇到CUDA out of memory报错。

解决方法1:(没有测试过)

max_tokens设置为800的情况下,将设置为--max-target-positions 512 --max-source-positions 512(这会过滤掉长度超过512的sample)

解决方法2:(亲测有效)

添加--memory-efficient-fp16, 需要主要该选项会降低模型表现,尤其是在batch sizes较小的情况下。