Errors During Training for the Original Implementation and the Fixes for the Errors
#24
Opened by v2ray
https://huggingface.co/v2ray/dbrx-base-fixed
The original DBRX implementation has a few bugs that only affect training, and I fixed them in my re-upload.
I re-uploaded the model because the changes require the weight files to be converted, so if you want to use the fix, you need to re-download all of the weights!
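The MLP fix stores one tensor per expert instead of one fused tensor, so a checkpoint saved in the fused layout can't be loaded as-is. Below is a minimal sketch of that kind of conversion in PyTorch. It is not the script used for the re-upload. The key name `ffn.experts.mlp.w1` is hypothetical, and so is the fused layout: shape `(num_experts * d_ff, d_model)`, with expert `i` owning a contiguous block of rows. The real DBRX key names and layout may differ.

```python
import torch


def split_fused_experts(state_dict, key, num_experts):
    """Replace one fused expert tensor with one tensor per expert.

    Hypothetical layout: state_dict[key] has shape (num_experts * d_ff, d_model)
    and expert i owns rows [i * d_ff, (i + 1) * d_ff).
    """
    fused = state_dict.pop(key)
    # torch.chunk along dim 0 yields equal row blocks, one per expert.
    for i, expert_w in enumerate(torch.chunk(fused, num_experts, dim=0)):
        # .clone() gives each expert its own storage instead of leaving it
        # as a view into the original fused tensor.
        state_dict[f"{key}.{i}"] = expert_w.clone()
    return state_dict


# Toy checkpoint: 4 experts, d_ff = 3, d_model = 2 (hypothetical key name).
sd = {"ffn.experts.mlp.w1": torch.arange(24, dtype=torch.float32).reshape(12, 2)}
split_fused_experts(sd, "ffn.experts.mlp.w1", num_experts=4)
print(sorted(sd))
print(sd["ffn.experts.mlp.w1.1"])
```

The `.0`, `.1`, ... suffixes match the keys that `nn.ParameterList` produces. This also assumes the converted model keeps one parameter per expert in such a list.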
The issues and how I fixed them:
- Error when using gradient checkpointing: fixed by passing positional arguments instead, because `_gradient_checkpointing_func` doesn't support kwargs.
- VRAM usage explodes, ending in a `CUDA out of memory` error, when backpropagating through the MLP layer: fixed by splitting the experts' weights into separate tensors instead of using a single tensor for all the experts. I'm not sure why this fixes it, but maybe torch is trying to compute the gradient for every expert at once, which shouldn't happen since it's an MoE model.
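A minimal sketch of the first fix follows. It assumes that `_gradient_checkpointing_func` is `torch.utils.checkpoint.checkpoint` with `use_reentrant=True` bound through `functools.partial`, which is how recent `transformers` versions appear to build it. `ToyBlock` is a hypothetical stand-in for a DBRX decoder block.

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a decoder block; not the DBRX code."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask=None):
        out = self.proj(hidden_states)
        if attention_mask is not None:
            # Zero out masked positions (toy masking, for illustration only).
            out = out * attention_mask.unsqueeze(-1)
        return out


# Assumption: mirrors `_gradient_checkpointing_func`, i.e. torch's checkpoint
# with use_reentrant=True bound in advance.
gradient_checkpointing_func = partial(checkpoint, use_reentrant=True)

torch.manual_seed(0)
block = ToyBlock(8)
x = torch.randn(2, 4, 8, requires_grad=True)
mask = torch.ones(2, 4)

# Broken: the reentrant checkpoint raises ValueError on keyword inputs.
kwargs_rejected = False
try:
    gradient_checkpointing_func(block, hidden_states=x, attention_mask=mask)
except ValueError as err:
    kwargs_rejected = True
    print("kwargs rejected:", err)

# Fixed: pass the inputs positionally, in the order `forward` declares them.
out = gradient_checkpointing_func(block, x, mask)
out.sum().backward()
print(x.grad.shape)
```

With `use_reentrant=False`, `torch.utils.checkpoint.checkpoint` does forward keyword arguments to the wrapped function. Passing the inputs positionally works in both modes.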
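Here is a toy comparison of the two weight layouts for the second fix. The class names and shapes are hypothetical, and the explanation is a plausible one rather than a confirmed diagnosis. Backpropagating through an index into a fused tensor creates a zero-filled gradient the size of the *whole* tensor. Per-expert parameters, by contrast, only get gradients for the experts that were actually routed to.

```python
import torch
import torch.nn as nn


class FusedExperts(nn.Module):
    """All experts' first-layer weights in one parameter (hypothetical layout)."""

    def __init__(self, num_experts: int, d_model: int, d_ff: int) -> None:
        super().__init__()
        self.num_experts, self.d_ff = num_experts, d_ff
        self.w1 = nn.Parameter(torch.randn(num_experts * d_ff, d_model) * 0.02)

    def expert_weight(self, i: int) -> torch.Tensor:
        # Index into a view of the fused tensor. The backward of this indexing
        # op yields a zero-filled gradient as large as the whole fused tensor.
        return self.w1.view(self.num_experts, self.d_ff, -1)[i]


class SplitExperts(nn.Module):
    """One independent parameter per expert."""

    def __init__(self, num_experts: int, d_model: int, d_ff: int) -> None:
        super().__init__()
        self.w1 = nn.ParameterList(
            [nn.Parameter(torch.randn(d_ff, d_model) * 0.02) for _ in range(num_experts)]
        )

    def expert_weight(self, i: int) -> torch.Tensor:
        return self.w1[i]


def run_expert(experts: nn.Module, i: int, x: torch.Tensor) -> torch.Tensor:
    """Apply expert i's first projection plus ReLU to the tokens routed to it."""
    return torch.relu(x @ experts.expert_weight(i).t())


torch.manual_seed(0)
num_experts, d_model, d_ff = 4, 8, 16
x = torch.randn(3, d_model)  # 3 tokens, all routed to expert 0

fused = FusedExperts(num_experts, d_model, d_ff)
run_expert(fused, 0, x).sum().backward()
print(fused.w1.grad.shape)  # torch.Size([64, 8]): covers all 4 experts

split = SplitExperts(num_experts, d_model, d_ff)
run_expert(split, 0, x).sum().backward()
print([p.grad is not None for p in split.w1])  # [True, False, False, False]
```

A real MoE layer loops over the active experts. In the fused layout, that loop could create one full-size temporary gradient per active expert during backward, which would be consistent with the VRAM spike described above.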
cc: @daking and @abhi-mosaic
Error 1 should be fixed in this PR.
We are currently working on fixing Error 2 in the same PR. In the meantime, for a workaround, please see: huggingface.co/databricks/dbrx-instruct/discussions/10#660566f14f41c0c7c0e54ab9