Skip to content

Commit 471a3c5

Browse files
committed
functions Ramón
1 parent bb179d6 commit 471a3c5

2 files changed

Lines changed: 82 additions & 0 deletions

File tree

toolkit/data_processing.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,3 +328,69 @@ def _exponential_smooth(data, alpha):
328328
smoothed_data.iloc[-1] = data.iloc[-1]
329329

330330
return smoothed_data
331+
332+
333+
def add_previous(df, n, clas, values):
334+
"""
335+
Add columns to the dataframe with the values of the last n events for each class.
336+
337+
Parameters:
338+
df (DataFrame): The input dataframe.
339+
n (int): The number of previous events to include in the output dataframe.
340+
clas (str): name of the column you want to obtain the previous values of.
341+
values (str): name of the column whose previous values you need
342+
343+
Returns:
344+
None
345+
"""
346+
# Group the dataframe by the clas column
347+
grouped = df.groupby(clas)
348+
349+
# Initialize a list to store the shifted values
350+
shifted_values = []
351+
352+
# Shift the values within each group n times to get the previous values of your clas
353+
for i in range(1, n + 1):
354+
shifted_values.append(grouped[values].shift(i))
355+
356+
# Concatenate the shifted values with the original dataframe
357+
new_cols = [f'Previous_value-{i}' for i in range(1, n + 1)]
358+
for i, col in enumerate(new_cols):
359+
df[col] = shifted_values[i]
360+
361+
return df
362+
363+
def winner_loser(x, df, column):
364+
365+
"""
366+
Comparator of odd and even rows, checks which one is a bigger value and returns Victory, Loss or Draw
367+
according to that. Prepared for sports, but appliable to other uses.
368+
369+
Args:
370+
x (int): number of the index
371+
df (df): dataframe to work in
372+
column (str): name of the column we want to compare
373+
374+
Return: Victory, Draw or Loss
375+
376+
"""
377+
if (x+2) % 2 == 0:
378+
if df[column][x] > df[column][x+1]:
379+
x = 'Victory'
380+
return x
381+
elif df[column][x] < df[column][x+1]:
382+
x = 'Loss'
383+
return x
384+
else:
385+
x = 'Draw'
386+
return x
387+
if (x+2) % 2 != 0:
388+
if df[column][x] > df[column][x-1]:
389+
x = 'Victory'
390+
return x
391+
elif df[column][x] < df[column][x-1]:
392+
x = 'Loss'
393+
return x
394+
else:
395+
x = 'Draw'
396+
return x

toolkit/machine_learning.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,3 +295,19 @@ def download_image(download_path, url, file_name):
295295
download_image(download_dir, url, str(i) + ".jpg")
296296

297297
wd.quit()
298+
299+
300+
def worst_params(gridsearch):
301+
'''
302+
Function to obtain the worst params of a gridsearch. In case we need to train a gridsearch multiple times,
303+
it can be useful to know which parameters are likely to be deleted, in order to make our training faster.
304+
305+
Args:
306+
gridsearch: trained gridsearch
307+
308+
'''
309+
position = list(gridsearch['rank_test_score']).index(gridsearch['rank_test_score'].max())
310+
worst_params = gridsearch['params'][position]
311+
worst_scoring = gridsearch['mean_test_score'][position]
312+
313+
return str(worst_params), worst_scoring

0 commit comments

Comments
 (0)