feat: replace traversal_paths with access_paths #14

Open · wants to merge 1 commit into base: main
tests/unit/test_encoder.py (3 additions, 3 deletions)
@@ -99,7 +99,7 @@ def test_pooling_strategy(pooling_strategy: str):
 
 
 @pytest.mark.parametrize(
-    'traversal_paths, counts',
+    'access_paths, counts',
     [
         ('@r', [['@r', 1], ['@c', 0], ['@cc', 0]]),
         ('@c', [['@r', 0], ['@c', 3], ['@cc', 0]]),
@@ -108,7 +108,7 @@ def test_pooling_strategy(pooling_strategy: str):
     ],
 )
 def test_traversal_path(
-    traversal_paths: str, counts: List, basic_encoder: TransformerTorchEncoder
+    access_paths: str, counts: List, basic_encoder: TransformerTorchEncoder
 ):
     text = 'blah'
     docs = DocumentArray([Document(id='root1', text=text)])
@@ -122,7 +122,7 @@ def test_traversal_path(
         Document(id='chunk112', text=text),
     ]
 
-    basic_encoder.encode(docs=docs, parameters={'traversal_paths': traversal_paths})
+    basic_encoder.encode(docs=docs, parameters={'access_paths': access_paths})
     for path, count in counts:
         embeddings = docs[path].embeddings
         if count != 0:
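Note: the '@r', '@c', and '@cc' strings parametrized above are docarray access-path selectors. A minimal sketch of what each one addresses, assuming the pre-0.30 docarray API that jina re-exports (the ids mirror the test fixture but are otherwise illustrative):

from jina import Document, DocumentArray

text = 'blah'
root = Document(id='root1', text=text)
# Three first-level chunks: the count the test expects for '@c'.
root.chunks = [Document(id=f'chunk{i}', text=text) for i in (1, 2, 3)]
# Two second-level chunks: the count the test expects for '@cc'.
root.chunks[0].chunks = [
    Document(id='chunk111', text=text),
    Document(id='chunk112', text=text),
]
docs = DocumentArray([root])

assert len(docs['@r']) == 1   # the root document itself
assert len(docs['@c']) == 3   # first-level chunks
assert len(docs['@cc']) == 2  # chunks of chunks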
transform_encoder.py (15 additions, 6 deletions)
@@ -10,6 +10,7 @@
 
 from jina import DocumentArray, Executor, requests
 
+import warnings
 
 class TransformerTorchEncoder(Executor):
     """The TransformerTorchEncoder encodes sentences into embeddings using transformers models."""
@@ -23,7 +24,8 @@ def __init__(
         max_length: Optional[int] = None,
         embedding_fn_name: str = '__call__',
         device: str = 'cpu',
-        traversal_paths: str = '@r',
+        access_paths: str = '@r',
+        traversal_paths: Optional[str] = None,
         batch_size: int = 32,
         *args,
         **kwargs,
@@ -41,14 +43,21 @@
             default the max length supported by the model will be used.
         :param embedding_fn_name: Function to call on the model in order to get output
         :param device: Torch device to put the model on (e.g. 'cpu', 'cuda', 'cuda:1')
-        :param traversal_paths: Used in the encode method an define traversal on the
+        :param access_paths: Used in the encode method and defines the traversal of the
             received `DocumentArray`
+        :param traversal_paths: Deprecated. Please use `access_paths`.
         :param batch_size: Defines the batch size for inference on the loaded
             PyTorch model.
         """
         super().__init__(*args, **kwargs)
 
-        self.traversal_paths = traversal_paths
+        if traversal_paths is not None:
+            self.access_paths = traversal_paths
+            warnings.warn("'traversal_paths' will be deprecated in the future, please use 'access_paths'.",
+                          DeprecationWarning,
+                          stacklevel=2)
+        else:
+            self.access_paths = access_paths
         self.batch_size = batch_size
 
         base_tokenizer_model = base_tokenizer_model or pretrained_model_name_or_path
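The constructor change above is a rename-with-deprecation shim: the old keyword keeps working but warns, and when both keywords are supplied the deprecated one silently wins. A minimal standalone sketch of the same pattern (the Example class is hypothetical, for illustration only):

import warnings
from typing import Optional


class Example:
    """Hypothetical class showing the rename-with-deprecation pattern."""

    def __init__(self, access_paths: str = '@r', traversal_paths: Optional[str] = None):
        if traversal_paths is not None:
            # Old spelling supplied: honor it, but warn the caller.
            self.access_paths = traversal_paths
            warnings.warn(
                "'traversal_paths' will be deprecated in the future, "
                "please use 'access_paths'.",
                DeprecationWarning,
                stacklevel=2,
            )
        else:
            self.access_paths = access_paths


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    e = Example(traversal_paths='@c')

assert e.access_paths == '@c'
assert any(issubclass(w.category, DeprecationWarning) for w in caught)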
@@ -73,16 +82,16 @@ def encode(self, docs: DocumentArray, parameters: Dict={}, **kwargs):
         each Document.
 
         :param docs: DocumentArray containing text
-        :param parameters: dictionary to define the `traversal_paths` and the
+        :param parameters: dictionary to define the `access_paths` and the
             `batch_size`. For example,
-            `parameters={'traversal_paths': 'r', 'batch_size': 10}`.
+            `parameters={'access_paths': '@r', 'batch_size': 10}`.
         :param kwargs: Additional keyword arguments.
         """
 
         docs_batch_generator = DocumentArray(
             filter(
                 lambda x: bool(x.text),
-                docs[parameters.get('traversal_paths', self.traversal_paths)],
+                docs[parameters.get('access_paths', self.access_paths)],
             )
         ).batch(batch_size=parameters.get('batch_size', self.batch_size))

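Taken together, a short usage sketch of the rename (the checkpoint name is an assumption; any transformers model the executor supports would do). Note that encode resolves only the 'access_paths' key from request parameters, so a per-request 'traversal_paths' is ignored after this change:

from jina import Document, DocumentArray

from transform_encoder import TransformerTorchEncoder

# Illustrative checkpoint; substitute any supported transformers model.
MODEL = 'sentence-transformers/all-MiniLM-L6-v2'

encoder = TransformerTorchEncoder(pretrained_model_name_or_path=MODEL)

docs = DocumentArray([Document(text='hello world')])

# New spelling, per request:
encoder.encode(docs=docs, parameters={'access_paths': '@r'})
print(docs['@r'].embeddings.shape)  # (1, <embedding dim>)

# Old spelling at construction time still works, but emits a DeprecationWarning:
legacy = TransformerTorchEncoder(
    pretrained_model_name_or_path=MODEL,
    traversal_paths='@r',
)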