Source code for czsc.utils.optuna

# -*- coding: utf-8 -*-
"""
author: zengbin93
email: zeng_bin8888@163.com
create_dt: 2024/3/21 13:56
describe: optuna 工具函数
"""
import hashlib
import optuna
import inspect
import pandas as pd


[docs]def optuna_study(objective, direction="maximize", n_trials=100, **kwargs): """使用optuna进行参数优化""" objective_code = inspect.getsource(objective) study_name = hashlib.md5(f"{objective_code}_{direction}".encode("utf-8")).hexdigest().upper()[:12] study = optuna.create_study(direction=direction, study_name=study_name) timeout = kwargs.pop("timeout", None) n_jobs = kwargs.pop("n_jobs", 1) study.optimize(objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs, **kwargs) return study
[docs]def optuna_good_params(study: optuna.Study, keep=0.2) -> pd.DataFrame: """获取optuna优化结果中的最优参数 :param study: optuna.study.Study :param keep: float, 保留最优参数的比例, 默认0.2 如果keep小于0,则按比例保留;如果keep大于0,则保留keep个参数组 :return: pd.DataFrame, 最优参数组列表 """ assert keep > 0, "keep必须大于0" params = [] for trail in study.trials: if trail.state != optuna.trial.TrialState.COMPLETE: continue if trail.value is None: continue p = {"params": trail.params, "objective": trail.value} params.append(p) n = int(len(params) * keep) if keep < 1 else int(keep) reverse = study.direction == 2 params = sorted(params, key=lambda x: x['objective'], reverse=reverse) dfp = pd.DataFrame(params[:n]) dfp = dfp.drop_duplicates(subset=['params'], keep='first').reset_index(drop=True) return dfp