Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions task_decomposition/analysis/compare_GPT_groundtruth.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,56 @@ def subtask_similarity(
score["total"] = TEMPORAL_WEIGHT * temporal_score + SEMANTIC_WEIGHT * semantic_score
return score

def subtask_similarity2(
subtask_decomp_A: list, subtask_decomp_B: list, DEBUG: bool = True
):
"""Com"""
assert len(subtask_decomp_A) > 0 and len(subtask_decomp_B) > 0
assert subtask_decomp_A[0][START_IDX] == subtask_decomp_B[0][START_IDX]
assert subtask_decomp_A[-1][END_IDX] == subtask_decomp_B[-1][END_IDX]

TEMPORAL_WEIGHT = 0.5
SEMANTIC_WEIGHT = 0.5

# TODO: assert subtask_decomp_A and subtask_decomp_B are non-overlapping
N = subtask_decomp_A[-1][END_IDX] + 1 # Assuming the last index is end index
temporal_score = 0
semantic_score = 0
temporal_weight_sum = 0
for subtask_a in subtask_decomp_A:
for subtask_b in subtask_decomp_B:
if intersection(subtask_a, subtask_b):
IOU = get_IOU(subtask_a, subtask_b)
_temporal = IOU
# REMOVED IOU
_semantic = (
semantic_distance(
subtask_a[SUBTASK_NAME_IDX], subtask_b[SUBTASK_NAME_IDX]
)
)
# Apply weight based on intersection length relative to total task length
temporal_weight = (
max(subtask_a[END_IDX], subtask_b[END_IDX])
- min(subtask_a[START_IDX], subtask_b[START_IDX])
+ 1
) / N
# semantic_weight = (
# min(subtask_a[END_IDX], subtask_b[END_IDX])
# - max(subtask_a[START_IDX], subtask_b[START_IDX])
# + 1
# ) / N
temporal_score += _temporal * temporal_weight
semantic_score += (
_semantic * temporal_weight
) # or without IOU and just semantic_weight
temporal_weight_sum += temporal_weight

score = {}
# Normalized the score with the temporal weight sum
score["temporal"] = temporal_score/temporal_weight_sum
score["semantic"] = semantic_score/temporal_weight_sum
score["total"] = TEMPORAL_WEIGHT * temporal_score + SEMANTIC_WEIGHT * semantic_score
return score

# %%
def test():
Expand Down