Skip to content

Error for template code in Task Group API (with python 3.12) #20

@Hazzng

Description

@Hazzng
# Stream all results from the group
async def get_all_results(client: parallel.AsyncParallel, taskgroup_id: str):
    """Collect company insights from every finished run in a task group.

    Issues a streaming GET against the group's runs path (with inputs and
    outputs included) and walks the resulting event stream. Events that are
    not ``TaskRunEvent`` instances, or whose ``output`` is falsy, are skipped.

    Args:
        client: Async client whose ``get`` method performs the request.
        taskgroup_id: Identifier interpolated into the runs path.

    Returns:
        A list of dicts with the keys ``company``, ``insights`` and
        ``market_position``, one per accepted event, in stream order.
    """
    runs_path = (
        f"/v1beta/tasks/groups/{taskgroup_id}/runs"
        "?include_input=true&include_output=true"
    )

    # NOTE(review): the PEP 604 unions passed here are what the report below
    # says fail on Python 3.12; they are kept as-is to preserve the snippet.
    event_stream = await client.get(
        path=runs_path,
        cast_to=TaskRunEvent | ErrorResponse | None,
        stream=True,
        stream_cls=parallel.AsyncStream[TaskRunEvent | ErrorResponse],
    )

    collected = []
    async for run_event in event_stream:
        # Guard clause: only task-run events carrying an output are recorded.
        if not isinstance(run_event, TaskRunEvent) or not run_event.output:
            continue

        parsed_input = CompanyInput.model_validate(run_event.input.input)
        parsed_output = CompanyOutput.model_validate(run_event.output.content)
        record = {
            "company": parsed_input.company_name,
            "insights": parsed_output.key_insights,
            "market_position": parsed_output.market_position,
        }
        collected.append(record)

    return collected

# NOTE: a top-level `await` is a SyntaxError in a plain script; this snippet
# presumably targets a notebook / async REPL. Otherwise run it from inside an
# async function (e.g. via asyncio.run).
results = await get_all_results(client, taskgroup_id)
print(f"Processed {len(results)} companies successfully")

This code doesn't work on Python 3.12: the line `cast_to=TaskRunEvent | ErrorResponse | None` builds a `types.UnionType` object, which cannot be weakly referenced, so the call raises `TypeError: cannot create weak reference to 'types.UnionType' object`.

Code fix: use `typing.Union` instead of the `|` operator to resolve this issue

from typing import Union

async def get_all_results(client: parallel.AsyncParallel, taskgroup_id: str):
    """Stream every run of a task group and collect the completed results.

    Requests the group's runs path with inputs and outputs included, then
    iterates the returned event stream. Each ``TaskRunEvent`` with a truthy
    ``output`` is validated into ``CompanyInput`` / ``CompanyOutput`` and
    flattened into a result dict; all other events are ignored.

    Args:
        client: Async client whose ``get`` method performs the request.
        taskgroup_id: Identifier interpolated into the runs path.

    Returns:
        A list with one dict per accepted event, in stream order. The list
        is empty (never ``None``) when the stream yields no accepted event.
    """
    results = []

    path = f"/v1beta/tasks/groups/{taskgroup_id}/runs"
    path += "?include_input=true&include_output=true"

    result_stream = await client.get(
        path=path,
        cast_to=Union[TaskRunEvent, ErrorResponse, None],  # Use typing.Union instead of |
        stream=True,
        stream_cls=parallel.AsyncStream[Union[TaskRunEvent, ErrorResponse]],  # Also update this
    )

    async for event in result_stream:
        if isinstance(event, TaskRunEvent) and event.output:
            company_input = CompanyInput.model_validate(event.input.input)
            company_output = CompanyOutput.model_validate(event.output.content)

            results.append(
                {
                    "vendor_gold_key": company_input.vendor_gold_key,
                    "vendor_name": company_input.company_name,
                    "tax_code": company_input.tax_code,
                    "country_code": company_input.country_code,
                    "research_summary": company_output.company_research_summary,
                    "is_exclusion": company_output.is_exclusion,
                }
            )

    # Return only once the stream is exhausted. With the return inside the
    # loop body, the function stopped after the first event and returned
    # None for an empty stream.
    return results

# NOTE: a top-level `await` is a SyntaxError in a plain script; this snippet
# presumably targets a notebook / async REPL. Otherwise run it from inside an
# async function (e.g. via asyncio.run).
results = await get_all_results(client, taskgroup_id)
print(f"Processed {len(results)} companies successfully")

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions