
Batch classification fails when using data_path in PytorchWildlife/models/classification/timm_base/base_classifier.py #611

@aweaver1fandm

Description


Search before asking

  • I have searched the Pytorch-Wildlife issues and found no similar bug report.

Bug

I'm using the DFNE classifier, but the problem is in:
PytorchWildlife/models/classification/timm_base/base_classifier.py

In that file, the dataset for batch classification is set up like this:
```python
if data_path:
    dataset = pw_data.ImageFolder(
        data_path,
        transform=self.transform,
        path_head='.'
    )
```

This raises a TypeError because `ImageFolder.__init__` does not accept a `path_head` keyword argument.
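A minimal, self-contained sketch of the failure, assuming only the `ImageFolder` constructor signature quoted below (`image_dir`, `transform=None`). `StubImageFolder` and `accepts_kwarg` are hypothetical names, not part of PytorchWildlife, so this runs without the library or PyTorch installed:

```python
import inspect


class StubImageFolder:
    """Hypothetical stand-in with the same __init__ signature as pw_data.ImageFolder."""

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform


def accepts_kwarg(cls, name):
    """Return True if cls.__init__ declares a parameter called `name`."""
    return name in inspect.signature(cls.__init__).parameters


# Mirrors the call in base_classifier.py, which passes path_head.
try:
    StubImageFolder("images", transform=None, path_head=".")
    error = None
except TypeError as exc:
    error = exc

print(accepts_kwarg(StubImageFolder, "path_head"))  # False
print(type(error).__name__)  # TypeError
```

Because Python rejects undeclared keyword arguments at call time, the batch path fails before any image is read.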

This is the current definition of ImageFolder (note that there is no path_head parameter):

```python
class ImageFolder(Dataset):
    """
    A PyTorch Dataset for loading images from a specified directory.
    Each item in the dataset is a tuple containing the image data,
    the image's path, and the original size of the image.
    """

    def __init__(self, image_dir, transform=None):
        """
        Initializes the dataset.

        Parameters:
            image_dir (str): Path to the directory containing the images.
            transform (callable, optional): Optional transform to be applied on the image.
        """
        super(ImageFolder, self).__init__()
        self.image_dir = image_dir
        self.transform = transform
        self.images = [os.path.join(dp, f) for dp, dn, filenames in os.walk(image_dir) for f in filenames if is_image_file(f)] # dp: directory path, dn: directory name, f: filename
```
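One possible fix is simply to drop `path_head='.'` from the call. The sketch below illustrates that under stated assumptions: `ImageFolderSketch` is a simplified stand-in for `ImageFolder` (no torch `Dataset` base, no image loading), `is_image_file` is a hypothetical extension check rather than the library's helper, and `build_batch_dataset` is a hypothetical wrapper around the `if data_path:` branch:

```python
import os
import tempfile

# Assumption: a simple extension check standing in for the library's is_image_file.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


def is_image_file(filename):
    return filename.lower().endswith(IMAGE_EXTENSIONS)


class ImageFolderSketch:
    """Simplified stand-in for pw_data.ImageFolder with the same constructor signature."""

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # Recursively collect image paths, mirroring the os.walk comprehension above.
        self.images = [
            os.path.join(dirpath, f)
            for dirpath, _dirnames, filenames in os.walk(image_dir)
            for f in filenames
            if is_image_file(f)
        ]

    def __len__(self):
        return len(self.images)


def build_batch_dataset(data_path, transform=None, dataset_cls=ImageFolderSketch):
    """Hypothetical helper: the batch setup with the unsupported path_head argument removed."""
    if data_path:
        return dataset_cls(data_path, transform=transform)
    return None


# Usage: two images (one in a subfolder) plus a non-image file.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "sub"))
    for name in ("a.jpg", os.path.join("sub", "b.PNG"), "notes.txt"):
        open(os.path.join(root, name), "w").close()
    dataset = build_batch_dataset(root)
    print(len(dataset))  # 2
```

Removing the keyword is the smallest change. The alternative, adding an optional `path_head` parameter to `ImageFolder`, would need its intended behavior defined first, which this report does not specify.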

Environment

No response

Minimal Reproducible Example

No response

Additional

No response

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!

Metadata


Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: Status Backlog
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
