Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cross validation with early stopping, dynamic eval_set #2

Open
c60evaporator opened this issue May 3, 2022 · 0 comments
Open

Cross validation with early stopping, dynamic eval_set #2

c60evaporator opened this issue May 3, 2022 · 0 comments

Comments

@c60evaporator
Copy link
Owner

c60evaporator commented May 3, 2022

背景

There are a lot of demands that using early stopping in cross validation with dynamic eval_set.

各種Scikit-Learn APIにおいて、クロスバリデーション実施時にテストデータをeval_setに動的に渡したいという要望が各種Issuesで提起されている。
→実現すればパラメータチューニングにおけるearly_stoppingの実装が楽になる。

各種ライブラリにおける上記機能への対応状況を、本Issuesにまとめる

LightGBM

Training APIにおいては、上記機能(クロスバリデーション時にテストデータを eval_setに動的に渡す)をlightgbm.cv()メソッドで実現済

Scikit-Learn APIにおいても同様の要望が出ているが、まだ実施されていない模様
microsoft/LightGBM#3313

実施の障害となっているのが、Scikit-Learn APIのfit()メソッドにX, y, sample_weight引数を渡さないというルールへの対応。こちらへの対応が終わらないとScikit-Learn APIの各種改良が難しい
microsoft/LightGBM#2966 (comment)

上記機能とは若干ずれるが、パイプライン用変数をeval_setに渡した時、transformerが適用されていないというIssuesもあり
microsoft/LightGBM#5090

XGBoost

同様の要望が以下のIssuesで提起(ただしLightGBMと同様にfit()メソッドへ余分な引数が渡されている問題を先にクリアする必要あり)
dmlc/xgboost#7782

本ライブラリでの対応状況

seaborn_analyzer.cross_val_score_eval_setで暫定対応。
以下のように実装することで、クロスバリデーション実施時にテストデータをeval_setに渡せる

from seaborn_analyzer import cross_val_score_eval_set
# 使用するパラメータ
param = {'objective': 'regression',  # 最小化させるべき損失関数
        'random_state': 42,  # 乱数シード
        'boosting_type': 'gbdt',  # boosting_type
        'n_estimators': 10000  # 最大学習サイクル数。early_stopping使用時は大きな値を入力
        }
verbose_eval = 0  # この数字を1にすると学習時のスコア推移がコマンドライン表示される
# early_stoppingを指定してLightGBM学習
lgbr = lgb.LGBMRegressor(**param)
# クロスバリデーション内部で`fit()`メソッドに渡すパラメータ
fit_params = {'eval_metric':'rmse',
              'eval_set':[(X, y)],
              'early_stopping_rounds': 10,
              'verbose': verbose_eval}
# クロスバリデーション実行
scores = cross_val_score_eval_set(
        eval_set_selection='test',  # 'test'と指定するとテストデータを'eval_set'に渡せる
        estimator=lgbr,  # 学習器
        X=X, y=y,  # クロスバリデーション分割前のデータを渡す
        scoring='neg_root_mean_squared_error',  # RMSE(の逆数)を指定
        cv=cv, verbose=verbose_eval, fit_params=fit_params
        )

本ライブラリ(tune-easy)においては、クラス初期化時にeval_data_source引数='test'を渡すことで、クロスバリデーション実施時にテストデータをeval_setに渡せる。

当面はこの方法を使用するが、LightGBMおよびXGBoostのIssuesに動きがあり次第対応を考える
(本来は本家Scikit-Learnのcross_val_scoreメソッドをeval_setに対応するよう改良できればベストだが、APIの統一を図るハードルが非常に高そう)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant