当前位置: 首页 > news >正文

python自动搜索最佳超参数之GridSearchCV函数

介绍

当我们跑机器学习程序时,尤其是调节网络参数时,通常待调节的参数有很多,参数之间的组合更是繁复。依照注意力>时间>金钱的原则,人力手动调节注意力成本太高,非常不值得。For循环或类似于for循环的方法受限于太过分明的层次,不够简洁与灵活,注意力成本高,易出错。本文介绍sklearn模块的GridSearchCV模块,能够在指定的范围内自动搜索具有不同超参数的不同模型组合,有效解放注意力。

GridSearchCV模块简介

这个模块是sklearn模块的子模块,导入方法非常简单

from sklearn.model_selection import GridSearchCV

函数原型:

class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=1, iid=True, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score='raise', return_train_score=True)

其中cv可以是整数或者交叉验证生成器或一个可迭代器,cv参数对应的4种输入列举如下:

  1. None:默认参数,函数会使用默认的3折交叉验证
  2. 整数k:k折交叉验证。对于分类任务,使用StratifiedKFold(类别平衡,每类的训练集占比一样多,具体可以查看官方文档)。对于其他任务,使用KFold
  3. 交叉验证生成器:得自己写生成器
  4. 可以生成训练集与测试集的迭代器

分析结果自动保存

逗号分隔值(Comma-Separated Values,CSV,有时也称为字符分隔值,因为分隔字符也可以不是逗号),其文件以纯文本形式存储表格数据(数字和文本)。纯文本意味着该文件是一个,不含必须像二进制数字那样被解读的数据。CSV文件由任意数目的记录组成,记录间以某种换行符分隔;每条记录由字段组成,字段间的分隔符是其它字符或字符串,最常见的是逗号或制表符。通常,所有记录都有完全相同的字段序列。

CSV文件有个突出的优点,可以用excel等软件打开,比起记事本和matlab、python等编程语言界面,便于查看、制作报告、后期整理等。

GridSearchCV模块中,不同超参数的组合方式及其计算结果以字典的形式保存在 clf.cv_results_中,python的pandas模块提供了高效整理数据的方法,只需要3行代码即可解决问题。

cv_result = pd.DataFrame.from_dict(clf.cv_results_)
with open('cv_result.csv','w') as f:cv_result.to_csv(f)

完整例程

代码清晰易懂,无须解释。https://github.com/JiJingYu/tensorflow-exercise/tree/master/svm_grid_search

import pandas as pd
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_reportiris = datasets.load_iris()
parameters = {'kernel':('linear', 'rbf'), 'C':[1, 2, 4], 'gamma':[0.125, 0.25, 0.5 ,1, 2, 4]}
svr = svm.SVC()
clf = GridSearchCV(svr, parameters, n_jobs=-1)
clf.fit(iris.data, iris.target)
cv_result = pd.DataFrame.from_dict(clf.cv_results_)
with open('cv_result.csv','w') as f:cv_result.to_csv(f)print('The parameters of the best model are: ')
print(clf.best_params_)y_pred = clf.predict(iris.data)
print(classification_report(y_true=iris.target, y_pred=y_pred))

http://www.taodudu.cc/news/show-1782047.html

相关文章:

  • 【physx/wasm】在physx中添加自定义接口并重新编译wasm
  • excel---常用操作
  • Lora训练Windows[笔记]
  • linux基础指令讲解(ls、pwd、cd、touch、mkdir)
  • InnoDB 事务处理机制
  • 启明云端ESP32 C3 模组WT32C3通过 MQTT 连接 AWS
  • 请实现一个函数,将一个字符串中的每个空格替换成...
  • 用两个栈来实现一个队列,完成队列的Push和Pop操作。 队列中的元素为int类型。
  • 我们可以用2*1的小矩形横着或者竖着去覆盖更大的矩形。请问用n个2*1的小矩形无重叠地覆盖一个2*n的大矩形,总共有多少种方法?
  • 3D Bounding Box Estimation Using Deep Learning and Geometry
  • Deep manta算法解析
  • GS3D An Efficient 3D Object Detection Framework for Autonomous Driving算法解析
  • python机器学习库xgboost使用调参
  • LightGBM算法解析
  • 机器学习模型之集成算法
  • ESPNet: Efficient Spatial Pyramid of Dilated Convolutions for Semantic Segmentation(自动驾驶领域轻量级模型)
  • 中缀表达式转后缀表达式(非常简单易懂)
  • 后缀表达式转中缀表达式(非常简单易懂)
  • 给定一列非负整数,求这些数连接起来能组成的最大的数。
  • 努力找工作中。。。
  • 手撕代码之快速排序算法(简单明了)
  • 小明很喜欢数学,有一天他在做数学作业时,要求计算出9~16的和,他马上就写出了正确答案是100。但是他并不满足于此,他在想究竟有多少种连续的正数序列的和为100(至少包括两个数)。
  • 小Q正在给一条长度为n的道路设计路灯安置方案。 为了让问题更简单,小Q把道路视为n个方格,需要照亮的地方用'.'表示, 不需要照亮的障碍物格子用'X'表示。
  • 牛牛以前在老师那里得到了一个正整数数对(x, y), 牛牛忘记他们具体是多少了。 但是牛牛记得老师告诉过他x和y均不大于n, 并且x除以y的余数大于等于k。 牛牛希望你能帮他计算一共有,,,
  • C++实现选择排序
  • C++实现希尔排序
  • 给出两个 非空 的链表用来表示两个非负的整数。其中,它们各自的位数是按照 逆序 的方式存储的,并且它们的每个节点只能存储 一位 数字。 如果,我们将这两个数相加起来,则会返回一个新的链表来表示,,,
  • CatBoost之算法解析(Kaggle常用模型)
  • ElasticNet算法解析
  • SVM支持向量机算法详解
  • SVM中的一些关键点解析
  • 朴素贝叶斯算法解析
  • Kmeans算法解析(非常详细)
  • DBSCAN(自适应密度聚类)算法解析
  • ID3、C4.5、CART决策树算法解析(关键内容讲解)
  • 面试之手撕BP反向传播