Update LAION-example and default value of scheduled_tasks #1092

Open · wants to merge 2 commits into main
4 changes: 2 additions & 2 deletions docs/source/torchdata.datapipes.iter.rst
@@ -31,7 +31,7 @@ We have different types of Iterable DataPipes:
 6. IO - interacting with the file systems or remote server (e.g. downloading, opening,
 saving files, and listing the files in directories).

-7. Mapping - apply the a given function to each element in the DataPipe.
+7. Mapping - apply a given function to each element in the DataPipe.

8. Others - perform miscellaneous set of operations.

@@ -156,7 +156,7 @@ saving files, and listing the files in directories).
 Mapping DataPipes
 -------------------------

-These DataPipes apply the a given function to each element in the DataPipe.
+These DataPipes apply a given function to each element in the DataPipe.

.. autosummary::
:nosignatures:
31 changes: 15 additions & 16 deletions examples/vision/laion5b.py
@@ -11,11 +11,12 @@
 from torchdata.datapipes.iter import HuggingFaceHubReader

 try:
-    import PIL
     from PIL import Image
 except ImportError:
-    PIL = None
-    Image = None
+    raise ModuleNotFoundError(
+        "Package `PIL` is required to be installed to run this example. "
+        "Please use `pip install Pillow` or `conda install -c anaconda pillow` to install it."
+    )


def has_no_watermark(x):
@@ -29,6 +30,7 @@ def is_sfw(x):
 def load_image(url):
     try:
         r = requests.get(url, timeout=5)
+        r.raise_for_status()
         return Image.open(BytesIO(r.content))
     except Exception:
         return None
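The pattern this hunk completes — fetch, fail on a bad HTTP status, and swallow any failure into `None` so a downstream filter can drop the sample — can be sketched with only the standard library. (`load_image_bytes` below is a hypothetical stand-in for the example's `load_image`; the real code uses `requests` and `PIL.Image.open`.)

```python
from io import BytesIO
from urllib.request import urlopen


def load_image_bytes(url):
    # Hypothetical stdlib stand-in for the example's load_image:
    # fetch the URL and return None on *any* failure so a downstream
    # filter (like image_was_loaded) can drop the sample.
    try:
        with urlopen(url, timeout=5) as r:
            # urlopen raises HTTPError on non-2xx responses, which is
            # the role r.raise_for_status() plays in the requests version.
            return BytesIO(r.read())
    except Exception:
        return None
```

A malformed URL then yields `None` instead of crashing the pipeline.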
@@ -50,7 +52,7 @@ def laion2b_en(name=NAME):
     dp = dp.filter(is_sfw)
     dp = dp.shuffle().sharding_filter()
     dp = dp.slice(index=["TEXT", "URL"])
-    dp = dp.map(fn=load_image, input_col="URL", output_col="IMAGE")  # this needs multithreading
+    dp = dp.threadpool_map(fn=load_image, input_col="URL", output_col="IMAGE")
     dp = dp.filter(filter_fn=image_was_loaded, input_col="IMAGE")
     dp = dp.drop("URL")
     dp = dp.batch(20)
@@ -59,18 +61,15 @@ def laion2b_en(name=NAME):

 def print_label_and_copyright(label, image):
     try:
-        try:
-            exif = image.getexif()
-            # 0x8298 is the EXIF-tag for copyright
-            copyright_info = exif.get(0x8298, "no info")
-        except Exception:
-            copyright_info = "EXIF data is corrupted"
-        if copyright_info != "no info" and copyright_info != "EXIF data is corrupted":
-            print(f"image {i}: {label=}, {copyright_info=} ")
-        else:
-            print(f"image {i}: {label=}")
-    except PIL.UnidentifiedImageError:
-        print(f"image {i}: corrupted")
+        exif = image.getexif()
+        # 0x8298 is the EXIF-tag for copyright
+        copyright_info = exif.get(0x8298, "no info")
+    except Exception:
+        copyright_info = "EXIF data is corrupted"
+    if copyright_info != "no info" and copyright_info != "EXIF data is corrupted":
+        print(f"image {i}: {label=}, {copyright_info=} ")
+    else:
+        print(f"image {i}: {label=}")


if __name__ == "__main__":
5 changes: 3 additions & 2 deletions torchdata/datapipes/iter/transform/callable.py
@@ -618,7 +618,7 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]):
         - Integer is used for list/tuple. ``-1`` represents to append result at the end.
         - Key is used for dict. New key is acceptable.

-    scheduled_tasks: How many tasks will be scheduled at any given time (Default value: 128)
+    scheduled_tasks: How many tasks will be scheduled at any given time (Default value: 500)
     max_workers: Maximum number of threads to execute function calls
     **threadpool_kwargs: additional arguments to be given to the ``ThreadPoolExecutor``
@@ -634,6 +634,7 @@ class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]):

     However, too high value of ``scheduled_tasks`` might lead to long waiting period until the first element is yielded
     as ``next`` is called ``scheduled_tasks`` many times on ``source_datapipe`` before yielding.
+    Additionally, this will lead to higher memory utilization.

     We encourage you to try out different values of ``max_workers`` and ``scheduled_tasks``
     in search for optimal values for your use-case.
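The scheduling behaviour the docstring describes — up to `scheduled_tasks` calls kept in flight, with the first yield delayed until that initial batch has been scheduled — can be illustrated with a small stdlib sketch. This is a simplification of the idea, not torchdata's actual `ThreadPoolMapperIterDataPipe` implementation:

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor


def threadpool_map(source, fn, scheduled_tasks=500, max_workers=None):
    """Yield fn(x) for each x in source, keeping up to `scheduled_tasks`
    calls in flight on a thread pool. Results come out in input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = deque()
        it = iter(source)
        # Schedule up to `scheduled_tasks` items before yielding anything --
        # this is the start-up latency (and memory) cost the docstring warns about.
        for x in it:
            futures.append(pool.submit(fn, x))
            if len(futures) >= scheduled_tasks:
                break
        for x in it:
            yield futures.popleft().result()  # oldest first, preserving order
            futures.append(pool.submit(fn, x))
        while futures:  # drain the remaining in-flight tasks
            yield futures.popleft().result()


# Order is preserved even though execution is concurrent:
squares = list(threadpool_map(range(8), lambda x: x * x, scheduled_tasks=3, max_workers=2))
```

The memory trade-off discussed in the review below falls out of this structure: at any moment up to `scheduled_tasks` completed-or-pending results sit buffered in the deque.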
@@ -695,7 +696,7 @@ def __init__(
         fn: Callable,
         input_col=None,
         output_col=None,
-        scheduled_tasks: int = 128,
+        scheduled_tasks: int = 500,
Contributor:
TBH, this highly depends on the size of the request and response.
If we are loading video files, such a high number might not provide the best performance.

Contributor (author):
Hmm, scheduled_tasks negatively affects performance if it is too low, yes. Higher values will increase memory usage (e.g. if the current video takes much longer to download than the next few, which are then prefetched and stored in memory), but in that case one is limited by internet connection speed anyway. Why should it negatively affect performance (throughput) if it's too high?

         max_workers: Optional[int] = None,
         **threadpool_kwargs,
     ) -> None: