Model Agnostic Sample Reweighting for Out-of-Distribution Learning

Xiao Zhou, Yong Lin, Renjie Pi, Weizhong Zhang, Renzhe Xu, Peng Cui, Tong Zhang

Research output: Contribution to journal › Conference article › peer-review

Abstract

Distributionally robust optimization (DRO) and invariant risk minimization (IRM) are two popular methods proposed to improve the out-of-distribution (OOD) generalization performance of machine learning models. While effective for small models, these methods have been observed to be vulnerable to overfitting with large overparameterized models. This work proposes a principled method, Model Agnostic samPLe rEweighting (MAPLE), to effectively address the OOD problem, especially in overparameterized scenarios. Our key idea is to find an effective reweighting of the training samples so that standard empirical risk minimization training of a large model on the weighted training data leads to superior OOD generalization performance. The overfitting issue is addressed by considering a bilevel formulation to search for the sample reweighting, in which the generalization complexity depends on the search space of sample weights instead of the model size. We present a theoretical analysis in the linear case to prove the insensitivity of MAPLE to model size, and empirically verify that it surpasses state-of-the-art methods by a large margin. Code is available at https://github.com/x-zho14/MAPLE.
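
The bilevel idea in the abstract can be illustrated with a toy sketch. This is not the authors' implementation (see the repository linked above for that). It assumes a ridge-regularized linear model solved in closed form as the inner weighted ERM step, a worst-environment mean squared error as the outer objective, softmax-parameterized sample weights, and a finite-difference outer search. All function names, data, and hyperparameters here are hypothetical choices made for illustration.

```python
import numpy as np

def weighted_erm(X, y, w, ridge=1e-3):
    """Inner problem: weighted ridge regression, solved in closed form."""
    A = X.T @ (w[:, None] * X) + ridge * np.eye(X.shape[1])
    return np.linalg.solve(A, X.T @ (w * y))

def outer_loss(theta, X, y, env):
    """Outer objective (an assumption): worst per-environment mean squared error."""
    return max(np.mean((X[env == e] @ theta - y[env == e]) ** 2)
               for e in np.unique(env))

def softmax(z):
    z = z - z.max()          # shift for numerical stability
    p = np.exp(z)
    return p / p.sum()

def search_weights(X, y, env, steps=30, lr=5.0, eps=1e-4):
    """Outer problem: search sample weights with finite-difference gradients."""
    def objective(z):
        return outer_loss(weighted_erm(X, y, softmax(z)), X, y, env)

    logits = np.zeros(len(y))                 # start from uniform weights
    best_logits, best_val = logits.copy(), objective(logits)
    for _ in range(steps):
        base = objective(logits)
        grad = np.zeros_like(logits)
        for i in range(len(logits)):          # one inner solve per sample weight
            bumped = logits.copy()
            bumped[i] += eps
            grad[i] = (objective(bumped) - base) / eps
        logits = logits - lr * grad
        val = objective(logits)
        if val < best_val:                    # keep the best weighting seen so far
            best_logits, best_val = logits.copy(), val
    return softmax(best_logits), best_val

# Synthetic data: a majority and a minority environment with opposite slopes.
rng = np.random.default_rng(0)
n_major, n_minor = 30, 6
x = rng.normal(size=n_major + n_minor)
env = np.array([0] * n_major + [1] * n_minor)
slope = np.where(env == 0, 1.0, -1.0)
y = slope * x + 0.1 * rng.normal(size=x.size)
X = np.column_stack([x, np.ones_like(x)])     # feature plus bias column

n = len(y)
uniform_val = outer_loss(weighted_erm(X, y, np.full(n, 1.0 / n)), X, y, env)
weights, tuned_val = search_weights(X, y, env)
print("worst-environment loss, uniform weights:", uniform_val)
print("worst-environment loss, searched weights:", tuned_val)
```

In this sketch the quantity being searched is the weight vector, while the model is refit by plain weighted ERM at every step, which mirrors the separation the abstract describes. A practical version would replace the finite-difference loop with a scalable gradient estimate and the closed-form solve with ordinary training of a large model.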

Original language: English (US)
Pages (from-to): 27203-27221
Number of pages: 19
Journal: Proceedings of Machine Learning Research
Volume: 162
State: Published - 2022
Externally published: Yes
Event: 39th International Conference on Machine Learning, ICML 2022 - Baltimore, United States
Duration: Jul 17 2022 → Jul 23 2022

ASJC Scopus subject areas

  • Artificial Intelligence
  • Software
  • Control and Systems Engineering
  • Statistics and Probability
