|
58 | 58 | "execution_count": null, |
59 | 59 | "metadata": {}, |
60 | 60 | "outputs": [], |
61 | | - "source": [ |
62 | | - "def create_interactive_demo():\n", |
63 | | - " \"\"\"Build slider and dropdown controls and display an interactive plot of a metric on randomly generated data.\"\"\"\n", |
64 | | - "\n", |
65 | | - " # Controls: sample count, positive-class ratio, random seed, and the metric to plot\n", |
66 | | - " n_samples_slider = widgets.IntSlider(\n", |
67 | | - " value=10, min=5, max=20, step=1,\n", |
68 | | - " description='N Samples:'\n", |
69 | | - " )\n", |
70 | | - "\n", |
71 | | - " pos_ratio_slider = widgets.FloatSlider(\n", |
72 | | - " value=0.5, min=0.1, max=0.9, step=0.1,\n", |
73 | | - " description='Pos Ratio:'\n", |
74 | | - " )\n", |
75 | | - "\n", |
76 | | - " seed_slider = widgets.IntSlider(\n", |
77 | | - " value=42, min=0, max=100, step=1,\n", |
78 | | - " description='Random Seed:'\n", |
79 | | - " )\n", |
80 | | - "\n", |
81 | | - " metric_dropdown = widgets.Dropdown(\n", |
82 | | - " options=['f1', 'accuracy', 'precision', 'recall'],\n", |
83 | | - " value='f1',\n", |
84 | | - " description='Metric:'\n", |
85 | | - " )\n", |
86 | | - "\n", |
87 | | - " def update_plot(n_samples, pos_ratio, seed, metric):\n", |
88 | | - " # Generate random data; the seed is applied to NumPy's global RNG first\n", |
89 | | - " np.random.seed(seed)\n", |
90 | | - " n_pos = int(n_samples * pos_ratio)\n", |
91 | | - " n_neg = n_samples - n_pos\n", |
92 | | - "\n", |
93 | | - " y_true = np.concatenate([np.zeros(n_neg), np.ones(n_pos)])\n", |
94 | | - " y_prob = np.random.beta(2, 2, n_samples) # Bell-shaped distribution\n", |
95 | | - "\n", |
96 | | - " # Sort labels and probabilities together by ascending probability for cleaner visualization\n", |
97 | | - " sort_idx = np.argsort(y_prob)\n", |
98 | | - " y_true = y_true[sort_idx]\n", |
99 | | - " y_prob = y_prob[sort_idx]\n", |
100 | | - "\n", |
101 | | - " # Clear the current figure, then call plot_piecewise_metric on the generated data\n", |
102 | | - " plt.clf()\n", |
103 | | - " fig, opt_thresh, opt_score = plot_piecewise_metric(\n", |
104 | | - " y_true, y_prob, metric,\n", |
105 | | - " title_suffix=f'\\n{n_samples} samples, {len(np.unique(y_prob))} unique probabilities'\n", |
106 | | - " )\n", |
107 | | - "\n", |
108 | | - " print(f\"Generated {n_samples} samples ({n_pos} positive, {n_neg} negative)\")\n", |
109 | | - " print(f\"Optimal {metric} threshold: {opt_thresh:.3f} (score = {opt_score:.3f})\")\n", |
110 | | - " print(f\"Number of breakpoints: {len(np.unique(y_prob))}\")\n", |
111 | | - "\n", |
112 | | - " # Bind the four controls to update_plot and display the resulting widget\n", |
113 | | - " interactive_plot = widgets.interactive(\n", |
114 | | - " update_plot,\n", |
115 | | - " n_samples=n_samples_slider,\n", |
116 | | - " pos_ratio=pos_ratio_slider,\n", |
117 | | - " seed=seed_slider,\n", |
118 | | - " metric=metric_dropdown\n", |
119 | | - " )\n", |
120 | | - "\n", |
121 | | - " display(interactive_plot)\n", |
122 | | - "\n", |
123 | | - "create_interactive_demo()" |
124 | | - ] |
| 61 | + "source": "def create_static_demo():\n \"\"\"Create static examples showing piecewise-constant behavior with different data characteristics.\"\"\"\n \n print(\"📊 STATIC EXAMPLES: Different Data Characteristics\")\n print(\"=\" * 55)\n \n # Example 1: Small imbalanced dataset \n print(\"\\n1️⃣ Small Imbalanced Dataset (5 samples, 20% positive)\")\n np.random.seed(42)\n y_ex1 = np.array([0, 0, 0, 1, 1])\n p_ex1 = np.array([0.1, 0.3, 0.4, 0.7, 0.9])\n fig1, opt1, score1 = plot_piecewise_metric(y_ex1, p_ex1, 'f1', \n title_suffix='\\nSmall Imbalanced Dataset')\n print(f\" → Optimal F1: {opt1:.3f} (score = {score1:.3f})\")\n print(f\" → Breakpoints: {len(np.unique(p_ex1))} unique probabilities\")\n \n # Example 2: Larger balanced dataset\n print(\"\\n2️⃣ Larger Balanced Dataset (20 samples, ~50% positive)\")\n np.random.seed(123)\n y_ex2 = np.random.randint(0, 2, 20)\n p_ex2 = np.random.beta(2, 2, 20) # Bell-shaped distribution\n # Sort for cleaner visualization\n sort_idx = np.argsort(p_ex2)\n y_ex2, p_ex2 = y_ex2[sort_idx], p_ex2[sort_idx]\n \n fig2, opt2, score2 = plot_piecewise_metric(y_ex2, p_ex2, 'f1', \n title_suffix='\\nLarger Balanced Dataset')\n print(f\" → Optimal F1: {opt2:.3f} (score = {score2:.3f})\")\n print(f\" → Breakpoints: {len(np.unique(p_ex2))} unique probabilities\")\n \n # Example 3: Precision vs Recall trade-off\n print(\"\\n3️⃣ Precision vs Recall Comparison\")\n y_ex3 = np.array([0, 0, 1, 1, 0, 1, 0, 1])\n p_ex3 = np.array([0.1, 0.3, 0.4, 0.6, 0.65, 0.8, 0.85, 0.9])\n \n # Compare different metrics on same data\n metrics_to_compare = ['precision', 'recall', 'f1']\n print(f\" Data: {len(y_ex3)} samples, {y_ex3.sum()} positive\")\n \n for metric in metrics_to_compare:\n result = optimize_thresholds(y_ex3, p_ex3, metric=metric)\n optimal_thresh = result.thresholds[0]\n optimal_score = _metric_score(y_ex3, p_ex3, optimal_thresh, metric)\n print(f\" → {metric.capitalize()}: t={optimal_thresh:.3f}, score={optimal_score:.3f}\")\n \n # Plot 
the trade-off\n thresholds = np.linspace(0.05, 0.95, 100)\n precision_scores = [_metric_score(y_ex3, p_ex3, t, 'precision') for t in thresholds]\n recall_scores = [_metric_score(y_ex3, p_ex3, t, 'recall') for t in thresholds]\n f1_scores = [_metric_score(y_ex3, p_ex3, t, 'f1') for t in thresholds]\n \n fig, ax = plt.subplots(1, 1, figsize=(12, 6))\n ax.plot(thresholds, precision_scores, 'g-', linewidth=2, label='Precision')\n ax.plot(thresholds, recall_scores, 'r-', linewidth=2, label='Recall') \n ax.plot(thresholds, f1_scores, 'b-', linewidth=2, label='F1 Score')\n \n # Mark optimal points\n for metric, color in zip(['precision', 'recall', 'f1'], ['green', 'red', 'blue']):\n result = optimize_thresholds(y_ex3, p_ex3, metric=metric)\n opt_t = result.thresholds[0]\n opt_s = _metric_score(y_ex3, p_ex3, opt_t, metric)\n ax.scatter([opt_t], [opt_s], color=color, s=150, marker='*', \n edgecolors='black', zorder=5)\n \n ax.set_xlabel('Decision Threshold')\n ax.set_ylabel('Metric Score')\n ax.set_title('Precision vs Recall Trade-off\\nStars show optimal thresholds for each metric')\n ax.grid(True, alpha=0.3)\n ax.legend()\n ax.set_ylim(0, 1.05)\n \n plt.tight_layout()\n plt.show()\n \n print(\"\\n💡 Key Insights:\")\n print(\" • Precision optimal: High threshold (fewer false positives)\")\n print(\" • Recall optimal: Low threshold (fewer false negatives)\") \n print(\" • F1 optimal: Balanced trade-off between precision and recall\")\n\n# Run the static demo\ncreate_static_demo()" |
125 | 62 | }, |
126 | 63 | { |
127 | 64 | "cell_type": "markdown", |
|
0 commit comments