@@ -217,4 +217,37 @@ def heatmap(df, n:int,target:str,columns:None):
217217
218218 plt .figure (figsize = (20 ,10 ))
219219 hm = sns .heatmap (cm , cbar = True , annot = True , cmap = 'YlOrBr' , fmt = '.2f' , yticklabels = cols .values , xticklabels = cols .values )
220- return hm
220+ return hm
221+
222+ def plot_roc_curve (y_true , y_pred , pos_label = 1 , figsize = (8 , 8 )):
223+ '''
224+ Function to plot the ROC curve of a binary classifier
225+
226+ Parameters:
227+
228+ y_true: true labels
229+ y_pred: model predictions
230+ pos_label: positive label (default: 1)
231+ figsize: figure size (default: (8, 8))
232+
233+ Returns:
234+ Lineplot of the ROC curve
235+
236+ '''
237+ # Compute the false positive rate, true positive rate, and thresholds
238+ fpr , tpr , thresholds = roc_curve (y_true , y_pred , pos_label = pos_label )
239+
240+ # Compute the area under the curve (AUC)
241+ roc_auc = auc (fpr , tpr )
242+
243+ # Create the ROC curve plot
244+ plt .figure (figsize = figsize )
245+ plt .plot (fpr , tpr , color = 'darkorange' , lw = 2 , label = 'ROC curve (AUC = %0.2f)' % roc_auc )
246+ plt .plot ([0 , 1 ], [0 , 1 ], color = 'navy' , lw = 2 , linestyle = '--' )
247+ plt .xlim ([0.0 , 1.0 ])
248+ plt .ylim ([0.0 , 1.05 ])
249+ plt .xlabel ('False Positive Rate' )
250+ plt .ylabel ('True Positive Rate' )
251+ plt .title ('Receiver operating characteristic (ROC) curve' )
252+ plt .legend (loc = "lower right" )
253+ plt .show ()
0 commit comments