MLSR.ssl 源代码

from sklearn.model_selection import GridSearchCV, train_test_split
# from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import MinMaxScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import f1_score, accuracy_score
from sklearn.metrics import confusion_matrix
from joblib import dump, load
import numpy as np
from time import strftime, localtime
from .data import DataSet
from .plot import plot_confusion_matrix


def grid_search_and_result_ssl(
        Xtrain: np.ndarray,
        ytrain: np.ndarray,
        Xtest: np.ndarray,
        ytest: np.ndarray,
        pipe: Pipeline,
        grid: dict,
        log_dir: str,
        score=None,
        verbose: int = 2,
        k: int = 5,
):
    """Run a cross-validated grid search and log train/test results.

    The fitted ``GridSearchCV`` object and its best estimator are dumped into
    ``log_dir`` with joblib (as ``gsCV`` and ``best_model``). Scores and
    confusion matrices are written to a timestamped ``<time>.log.txt`` file
    in ``log_dir``, and confusion-matrix PNGs are plotted next to it.

    Args:
        Xtrain: training features (labelled and unlabelled samples together).
        ytrain: training labels. NOTE(review): the train confusion-matrix
            class names below presume the labels are exactly (-1, 0, 1),
            with -1 meaning "unlabelled" — confirm for other callers.
        Xtest: test-set features.
        ytest: test-set labels.
        pipe: model pipeline to tune.
        grid: hyper-parameter search space (``param_grid``).
        log_dir: output directory for results; it must already exist.
        score: ``scoring`` passed to ``GridSearchCV``. Defaults to macro-F1
            (``'f1'``) and accuracy (``'accuracy'``). Because ``refit='f1'``,
            a multi-metric ``score`` dict must contain an ``'f1'`` key.
        verbose: verbosity level; 0 is silent. When > 0 the parameters and
            the training confusion matrix are also logged; when > 2 the full
            ``cv_results_`` is logged as well.
        k: number of cross-validation folds.

    Returns:
        The fitted ``GridSearchCV`` model.

    Raises:
        FileExistsError: a log file with the same timestamp already exists
            (the log is opened in exclusive-create mode).
        FileNotFoundError: ``log_dir`` does not exist.
    """
    import os

    file_prefix = os.path.join(log_dir, strftime("%Y_%m_%d_%H_%M_%S", localtime()))
    scoring = score if score is not None else {
        'f1': 'f1_macro',
        'accuracy': 'accuracy'
    }

    # Open the log before the (expensive) fit so a bad log_dir fails fast;
    # 'x' never overwrites an existing log, and ``with`` guarantees the file
    # is closed even if fitting, dumping or plotting raises.
    with open(file_prefix + '.log.txt', 'x', encoding='utf-8') as log:
        gs_cv = GridSearchCV(
            estimator=pipe,
            cv=k,
            n_jobs=-1,
            param_grid=grid,
            scoring=scoring,
            refit='f1',
            verbose=verbose
        )
        gs_cv.fit(Xtrain, ytrain)
        dump(gs_cv, os.path.join(log_dir, 'gsCV'))
        dump(gs_cv.best_estimator_, os.path.join(log_dir, 'best_model'))

        if verbose > 2:
            log.write(str(gs_cv.cv_results_))
        if verbose:
            log.write(str(gs_cv.get_params()))
        # best_score_ is the mean cross-validated score of the best candidate
        # (refit metric), which is what this label describes.
        log.write('\nBest score on training set by grid search cross validation: {}\n'
                  .format(gs_cv.best_score_))
        log.write('Score of the refitted best model on the whole training set: {}\n'
                  .format(gs_cv.score(Xtrain, ytrain)))

        # The in-memory best estimator is the same object just dumped above;
        # no need to round-trip it through disk.
        best_model = gs_cv.best_estimator_
        test_prediction = best_model.predict(Xtest)
        log.write('Accuracy on test set: {}\n'.format(accuracy_score(ytest, test_prediction)))
        log.write('F1-score on test set: {}\n'.format(
            f1_score(ytest, test_prediction, average='macro')))

        if verbose:
            cm = confusion_matrix(ytrain, best_model.predict(Xtrain))
            # '无标签' = "unlabelled"; presumes sorted training labels (-1, 0, 1).
            plot_confusion_matrix(cm, ['无标签', '0', '1'], file_prefix + '_train_cm.png')
            log.write('\ntrain_cm:\n')
            log.write(str(cm))
        cm = confusion_matrix(ytest, test_prediction)
        plot_confusion_matrix(cm, ['0', '1'], file_prefix + '_test_cm.png')
        log.write('\ntest_cm:\n')
        log.write(str(cm))
        # plot_roc(best_model, Xtest, ytest, file_prefix + '_roc.png')
    return gs_cv
def do_tsvm(data: DataSet, log_dir: str = '../log', grid: dict = None,
            train_size: float = 0.4, random_state=None):
    """Tune and evaluate a Transductive Support Vector Machine (TSVM).

    Samples whose strong label is ``-1`` are treated as unlabelled. The
    labelled samples are split into train and test sets. All unlabelled
    samples, still labelled ``-1``, are then appended to the training set,
    and the result is passed to ``grid_search_and_result_ssl``.

    Args:
        data: input ``DataSet``. Its ``features`` and ``strong_label`` are
            assumed to be row-aligned (same length and order) — TODO confirm.
        log_dir: directory for result files; it must already exist.
        grid: hyper-parameter search grid. When omitted, a default grid over
            ``kernel``, ``n_iter`` and ``Cu`` is used.
        train_size: fraction of the labelled samples used for training; the
            remainder forms the test set.
        random_state: seed forwarded to ``train_test_split`` for a
            reproducible split; ``None`` leaves the split unseeded.

    Returns:
        The fitted ``GridSearchCV`` model.
    """
    from .tsvm import TSVM

    if grid is None:
        grid = {
            'tsvm__kernel': ['linear', 'poly', 'rbf', 'sigmoid'],
            'tsvm__n_iter': [10, 50],
            'tsvm__Cu': [0.0001, 0.001, 0.01]
        }
    pipe = Pipeline([
        ('scaler', MinMaxScaler()),
        ('tsvm', TSVM())
    ])

    # ``take`` is positional, so select rows by position rather than by index
    # label; this stays correct even when the data has a non-default index.
    is_labeled = np.asarray(data.strong_label) != -1
    labeled_pos = np.flatnonzero(is_labeled)
    unlabeled_pos = np.flatnonzero(~is_labeled)

    strong_features = data.features.take(labeled_pos)
    strong_label = data.strong_label.take(labeled_pos)
    unlabeled_features = data.features.take(unlabeled_pos)
    unlabeled_strong_label = data.strong_label.take(unlabeled_pos)

    Xtrain, Xtest, ytrain, ytest = train_test_split(
        strong_features, strong_label,
        train_size=train_size, random_state=random_state)
    # Transductive setting: unlabelled samples join the training set only.
    Xtrain = np.concatenate((Xtrain, unlabeled_features.values), axis=0)
    ytrain = np.concatenate((ytrain, unlabeled_strong_label.values), axis=0)
    return grid_search_and_result_ssl(Xtrain, ytrain, Xtest, ytest, pipe, grid, log_dir)