-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_training.py
More file actions
122 lines (100 loc) · 4.13 KB
/
model_training.py
File metadata and controls
122 lines (100 loc) · 4.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
import joblib
def main():
    """Train and compare three classifiers on the Breast Cancer Wisconsin dataset.

    Workflow:
      1. Load the dataset and relabel so 1 = malignant (positive class).
      2. Split 80/20 (stratified), then standardize using *training* statistics only.
      3. Train Logistic Regression, Random Forest, and SVM; report
         accuracy/precision/recall/F1 and plot ROC curves.
      4. Plot Random Forest feature importances.
      5. Persist the best model (by F1), the fitted scaler, and the feature
         names via joblib.

    Side effects: writes 'roc_curve_comparison.png', 'feature_importance.png',
    'best_model.pkl', 'scaler.pkl', 'feature_names.pkl' to the working directory.
    """
    print("Loading the Breast Cancer Wisconsin dataset...")
    data = load_breast_cancer()
    X = pd.DataFrame(data.data, columns=data.feature_names)
    # sklearn encodes target 0 = malignant, 1 = benign. Invert so that
    # 1 = malignant, making the positive class match medical intuition.
    y = np.where(data.target == 0, 1, 0)
    print("Dataset shape:", X.shape)

    # Train/test split (80/20), stratified to preserve class balance.
    # BUG FIX: split BEFORE scaling. The original fit the scaler on the full
    # dataset prior to splitting, leaking test-set statistics into the
    # preprocessing step and optimistically biasing the reported metrics.
    X_train_raw, X_test_raw, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    # Fit the scaler on the training fold only; apply the same transform
    # (training mean/std) to the held-out test fold.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train_raw)
    X_test = scaler.transform(X_test_raw)
    print(f"Training set: {X_train.shape[0]} samples, Test set: {X_test.shape[0]} samples.")

    # Model initialization. probability=True is required so SVC exposes
    # predict_proba for the ROC curve.
    models = {
        'Logistic Regression': LogisticRegression(random_state=42),
        'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
        'Support Vector Machine': SVC(probability=True, random_state=42)
    }

    # Training and evaluation; track the best model by F1-score.
    best_model_name = ""
    best_f1 = 0
    best_model = None
    results = []
    plt.figure(figsize=(10, 8))
    for name, model in models.items():
        print(f"\nTraining {name}...")
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        y_proba = model.predict_proba(X_test)[:, 1]
        acc = accuracy_score(y_test, y_pred)
        prec = precision_score(y_test, y_pred)
        rec = recall_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred)
        results.append({
            'Model': name,
            'Accuracy': acc,
            'Precision': prec,
            'Recall': rec,
            'F1-Score': f1
        })
        print(f"Metrics for {name}:")
        print(f"  Accuracy: {acc:.4f}")
        print(f"  Precision: {prec:.4f}")
        print(f"  Recall: {rec:.4f}")
        print(f"  F1-Score: {f1:.4f}")
        # ROC curve for this model, drawn onto the shared figure.
        fpr, tpr, _ = roc_curve(y_test, y_proba)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{name} (AUC = {roc_auc:.2f})')
        if f1 > best_f1:
            best_f1 = f1
            best_model_name = name
            best_model = model

    # Diagonal = random-classifier baseline.
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.legend(loc="lower right")
    plt.savefig('roc_curve_comparison.png')
    plt.close()
    print(f"\nBest Model: {best_model_name} with F1-Score: {best_f1:.4f}")

    # Feature importances always come from the Random Forest (the only model
    # here exposing feature_importances_), regardless of which model won.
    print("\nExtracting Feature Importances using Random Forest...")
    rf_model = models['Random Forest']
    importances = rf_model.feature_importances_
    indices = np.argsort(importances)[::-1]  # descending importance
    plt.figure(figsize=(12, 6))
    plt.title("Feature Importances (Random Forest)")
    plt.bar(range(X.shape[1]), importances[indices], align="center")
    plt.xticks(range(X.shape[1]), np.array(data.feature_names)[indices], rotation=90)
    plt.xlim([-1, X.shape[1]])
    plt.tight_layout()
    plt.savefig('feature_importance.png')
    plt.close()

    # Persist artifacts for downstream use (e.g. a serving app): the winning
    # model, the training-fold scaler, and the ordered feature names.
    joblib.dump(best_model, 'best_model.pkl')
    joblib.dump(scaler, 'scaler.pkl')
    joblib.dump(list(data.feature_names), 'feature_names.pkl')
    print("\nModel, scaler, and feature names saved successfully.")
# Script entry point: run the full train/evaluate/save pipeline when
# executed directly (not on import).
if __name__ == "__main__":
    main()