S^2FT: Efficient, Scalable and Generalizable LLM Fine-tuning by Structured Sparsity
Xinyu Yang, Jixuan Leng, Geyang Guo, Jiawei Zhao, Ryumei Nakada, Linjun Zhang, Huaxiu Yao, and Beidi Chen
In NeurIPS, 2024
Current PEFT methods for LLMs can achieve either high quality, efficient training, or scalable serving, but not all three simultaneously. To address this limitation, we investigate sparse fine-tuning and observe a remarkable improvement in generalization ability. Utilizing this key insight, we propose a family of Structured Sparse Fine-Tuning (S^2FT) methods for LLMs, which concurrently achieve state-of-the-art fine-tuning performance, training efficiency, and inference scalability. S^2FT accomplishes this by "selecting sparsely and computing densely". It selects a few heads and channels in the MHA and FFN modules, respectively, of each Transformer block. Next, it co-permutes the weight matrices on both sides of the coupled structures in LLMs to connect the selected components in each layer into a dense submatrix. Finally, S^2FT performs in-place gradient updates on all submatrices. Through theoretical analysis and empirical results, we show that our method prevents overfitting and forgetting, delivers SOTA performance on both commonsense and arithmetic reasoning with 4.6% and 1.3% average improvements over LoRA, and outperforms full FT by 11.5% when generalizing to various domains after instruction tuning. By integrating our partial back-propagation algorithm, S^2FT reduces fine-tuning memory by up to 3× and improves latency by 1.5-2.7× compared to full FT, while delivering an average 10% improvement over LoRA on both metrics. We further demonstrate that S^2FT can be decoupled into adapters, enabling effective fusion, fast switching, and efficient parallelism for serving multiple fine-tuned models.
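The "select sparsely, compute densely" idea can be illustrated with a minimal NumPy sketch for a single FFN block. This is not the authors' implementation. It assumes a ReLU FFN, a random choice of channels, a squared-error loss, and that only the selected rows of the down-projection are trained. All names (`ffn`, `co_permute`, `w_up`, `w_down`) are hypothetical.

```python
import numpy as np

def ffn(x, w_up, w_down):
    """Two-layer FFN with ReLU (an assumption): relu(x @ w_up) @ w_down."""
    return np.maximum(x @ w_up, 0.0) @ w_down

def co_permute(w_up, w_down, selected):
    """Reorder hidden channels so the selected ones come first.

    The same permutation is applied to the columns of w_up and the rows of
    w_down, so the function computed by the FFN is unchanged.
    """
    d_ff = w_up.shape[1]
    rest = np.setdiff1d(np.arange(d_ff), selected)
    perm = np.concatenate([selected, rest])
    return w_up[:, perm], w_down[perm, :], perm

rng = np.random.default_rng(0)
d_model, d_ff, k = 8, 32, 4          # k = number of trainable channels
w_up = rng.standard_normal((d_model, d_ff))
w_down = rng.standard_normal((d_ff, d_model))
x = rng.standard_normal((5, d_model))
target = rng.standard_normal((5, d_model))

# 1) Select sparsely: pick a few hidden channels (random here, by assumption).
selected = rng.choice(d_ff, size=k, replace=False)

# 2) Co-permute: the selected channels now form the dense block w_down_p[:k].
w_up_p, w_down_p, perm = co_permute(w_up, w_down, selected)
out_before = ffn(x, w_up, w_down)
out_after = ffn(x, w_up_p, w_down_p)

# 3) Compute densely: gradient of 0.5 * ||y - target||^2 with respect to the
#    trainable block only. Only the first k hidden activations are needed.
h = np.maximum(x @ w_up_p, 0.0)
err = h @ w_down_p - target
grad_block = h[:, :k].T @ err        # shape (k, d_model)

frozen_copy = w_down_p[k:].copy()
w_down_p[:k] -= 0.01 * grad_block    # in-place update of the dense submatrix
```

Because the permutation is applied consistently on both sides of the hidden dimension, `out_before` and `out_after` agree up to floating-point error, and the rows `w_down_p[k:]` are never touched by the update. The difference `w_down_p[:k] - w_down[selected]` is then a small dense matrix that could be stored separately, in the spirit of the adapter view described in the abstract.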