
Error when running example MNIST_fishleg_CNN script on GPU #33

@WeiShengL


Running the examples/MNIST_fishleg_CNN.py script on a "cuda" device (by commenting out line 31) raises a type-mismatch error in the FishLeg `update_aux` method:

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

I presume this is because the auxiliary data loader (`aux_loader`) is set up to load data onto the CPU, so when the model is initialized on a GPU device, the input batches (CPU) and the model weights (GPU) end up on different devices.
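To illustrate where the mismatch comes from, here is a minimal sketch. It assumes a single `nn.Conv2d` layer as a hypothetical stand-in for the FishLeg-wrapped CNN and random tensors in place of MNIST. A `DataLoader` with the default `collate_fn` yields CPU tensors regardless of which device the model lives on.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the CNN: one conv layer whose weights are moved to `device`
model = torch.nn.Conv2d(1, 8, kernel_size=3).to(device)

# Hypothetical stand-in for the MNIST training set: 16 random 1x28x28 images with labels
dataset = TensorDataset(torch.randn(16, 1, 28, 28), torch.randint(0, 10, (16,)))
loader = DataLoader(dataset, shuffle=True, batch_size=4)

x, y = next(iter(loader))
weight_device = next(model.parameters()).device

# On a GPU machine this prints "cpu cuda:0"; calling model(x) there raises the error above
print(x.device, weight_device)
```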

This could perhaps be fixed by adding a line of code that moves each batch to the specified device.
Original code:

```python
aux_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size
)
```

New code:

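```python
from torch.utils.data.dataloader import default_collate

# Collate the batch as usual, then move each tensor in it (images, labels) to `device`
aux_loader = torch.utils.data.DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size,
    collate_fn=lambda x: tuple(x_.to(device) for x_ in default_collate(x))
)
```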
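An alternative to passing a `collate_fn` is to leave the loader untouched and move each batch to the device as it is consumed. The sketch below is only a suggestion under my own assumptions: `move_to_device` and `DeviceLoader` are hypothetical helpers, not part of FishLeg or PyTorch, and it assumes `update_aux` only iterates over `aux_loader` (I have not checked whether it relies on other `DataLoader` attributes). One reason to prefer this: with `num_workers > 0`, `collate_fn` runs in worker processes, where initializing CUDA can fail and a lambda cannot be pickled under the `spawn` start method.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def move_to_device(batch, device):
    """Hypothetical helper: recursively move tensors in a nested batch to `device`."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, list):
        return [move_to_device(b, device) for b in batch]
    if isinstance(batch, tuple):
        return tuple(move_to_device(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    return batch  # non-tensor items are passed through unchanged


class DeviceLoader:
    """Hypothetical wrapper: iterates a DataLoader, yielding batches already on `device`."""

    def __init__(self, loader, device):
        self.loader = loader
        self.device = device

    def __iter__(self):
        for batch in self.loader:
            yield move_to_device(batch, self.device)

    def __len__(self):
        return len(self.loader)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for train_dataset: 16 random 1x28x28 images with labels
dataset = TensorDataset(torch.randn(16, 1, 28, 28), torch.randint(0, 10, (16,)))
aux_loader = DeviceLoader(DataLoader(dataset, shuffle=True, batch_size=4), device)

x, y = next(iter(aux_loader))  # both tensors are now on `device`
```

Because the transfer happens in the main process, this keeps working if `num_workers` is raised later; the trade-off is one extra wrapper object instead of a single `collate_fn` argument.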
