-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update argilla
integration to use argilla_sdk
v2
#705
base: develop
Are you sure you want to change the base?
Conversation
@@ -20,23 +20,24 @@ | |||
from pydantic import Field, PrivateAttr, SecretStr | |||
|
|||
try: | |||
import argilla as rg | |||
import argilla_sdk as rg |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@frascuchon What do you think the timeline on renaming argilla_sdk
>> argilla
is? Is it worth waiting for the changes in argilla
so that it's in line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These changes are minimal. We can leave it as is until we havemore stable versions of argilla v2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice start. Just a few high level comments.
Co-authored-by: Ben Burtenshaw <[email protected]>
For the moment it's being installed as `pip install git+https://github.com/argilla-io/argilla-python.git@main`
Edit: the issue was with the Argilla Server version as I was using 1.26.0 while 1.27.0 or higher was required 👍🏻
Install as from uuid import uuid4
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps import (
LoadDataFromDicts,
TextGenerationToArgilla,
)
from distilabel.steps.tasks import TextGeneration
if __name__ == "__main__":
with Pipeline(name="my-pipeline") as pipeline:
load_dataset = LoadDataFromDicts(
name="load_dataset",
data=[
{
"instruction": "Write a short story about a dragon that saves a princess from a tower.",
},
],
)
text_generation = TextGeneration(
name="text_generation",
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3-8B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
api_key="...", # type: ignore
),
num_generations=4,
group_generations=True,
)
text_generation_to_argilla = TextGenerationToArgilla(
name="text_generation_to_argilla",
api_url="...",
api_key="...", # type: ignore
dataset_name=f"text-generation-{uuid4()}",
dataset_workspace="admin",
)
( # type: ignore
load_dataset
>> text_generation
>> text_generation_to_argilla
)
pipeline.run(
parameters={
text_generation.name: { # type: ignore
"llm": {
"generation_kwargs": {
"max_new_tokens": 512,
"temperature": 0.7,
},
},
},
}
) The logs then look like: |
return ( | ||
True | ||
if self._client.datasets( # type: ignore | ||
name=self.dataset_name, # type: ignore | ||
workspace=self._client.workspaces(name=self.dataset_workspace) # type: ignore | ||
if self.dataset_workspace is not None | ||
else None, | ||
).exists() | ||
is not None | ||
else False | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dataset.exists() should return already a bool value. So maybe extra checks are not needed.
return ( | |
True | |
if self._client.datasets( # type: ignore | |
name=self.dataset_name, # type: ignore | |
workspace=self._client.workspaces(name=self.dataset_workspace) # type: ignore | |
if self.dataset_workspace is not None | |
else None, | |
).exists() | |
is not None | |
else False | |
) | |
workspace = ( | |
self._client.workspaces(name=self.dataset_workspace) # type: ignore | |
if self.dataset_workspace is not None | |
else None | |
) | |
return ( | |
self._client.datasets( # type: ignore | |
name=self.dataset_name, # type: ignore | |
workspace=workspace, | |
).exists() | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't do it that way since it was returning None
not sure if that is still a thing within the version in the argilla
repository
], | ||
MockDataset = rg.Dataset( | ||
name="dataset", | ||
workspace=rg.Workspace("workspace"), # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@burtenshaw I think this is a side effect of calling API on Dataset.__init__
. Maybe we should review to force client-dependant actions only for methods calling the API and not init ones (for example workspace is only needed when creating the dataset)
Description
This PR renames and updates
Argilla
toArgillaBase
, since now the client inargilla_sdk
(later to be renamed toargilla
only as per a recent discussion with @frascuchon) is namedArgilla
too. Besides that, the code has been updated to use the latest Python client instead not only forArgillaBase
but also for the subclassesTextGenerationToArgilla
andPreferenceToArgilla
.Warning
This change here implies that the
argilla
server version should be 1.27.0 or higher, otherwise theargilla_sdk
won't work.Closes argilla-io/argilla#4880