diff --git a/Chapter 06/thompson_sampling.py b/Chapter 06/thompson_sampling.py index 03313f38..7096bc3b 100644 --- a/Chapter 06/thompson_sampling.py +++ b/Chapter 06/thompson_sampling.py @@ -24,6 +24,8 @@ total_reward_ts = 0 numbers_of_rewards_1 = [0] * d numbers_of_rewards_0 = [0] * d +# collect data for plot +results = [] for n in range(0, N): # Random Selection strategy_rs = random.randrange(d) @@ -45,7 +47,10 @@ numbers_of_rewards_0[strategy_ts] = numbers_of_rewards_0[strategy_ts] + 1 strategies_selected_ts.append(strategy_ts) total_reward_ts = total_reward_ts + reward_ts - + # collect data for plot every x iterations + if n%50==0: + results.append(numbers_of_rewards_1.copy()) + # Computing the Relative Return relative_return = (total_reward_ts - total_reward_rs) / total_reward_rs * 100 print("Relative Return: {:.0f} %".format(relative_return)) @@ -56,3 +61,15 @@ plt.xlabel('Strategy') plt.ylabel('Number of times the strategy was selected') plt.show() + +# Plotting animated plot +from matplotlib.animation import FuncAnimation +fig, ax = plt.subplots() +fig.set_tight_layout(True) + +def update(i): + return plt.plot(range(9), results[i]) +# FuncAnimation will call the 'update' function for each frame; here + # animating over all frames, with an interval of 200ms between frames. +anim = FuncAnimation(fig, update, frames=np.arange(0, len(results)), interval=200) +plt.show()