Update argilla integration to use argilla_sdk v2 #705

Open
wants to merge 10 commits into
base: develop

Member

@alvarobartt alvarobartt commented Jun 6, 2024

Description

This PR renames Argilla to ArgillaBase, because the client in argilla_sdk (to be renamed to just argilla, as per a recent discussion with @frascuchon) is also named Argilla. In addition, the code now uses the latest Python client, not only in ArgillaBase but also in its subclasses TextGenerationToArgilla and PreferenceToArgilla.

Warning

This change requires the Argilla server version to be 1.27.0 or higher; otherwise argilla_sdk won't work.
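
As an illustration of the version requirement above, here is a minimal sketch of a guard that a caller could run before using the new client. It assumes the server version is already available as a string; how to obtain it is not shown, and `parse_version`, `is_supported_server` and `MIN_SERVER_VERSION` are hypothetical names, not part of argilla_sdk.

```python
from typing import Tuple

# Minimum Argilla server version stated in the warning above.
MIN_SERVER_VERSION = (1, 27, 0)


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '1.27.0' into (1, 27, 0), ignoring suffixes such as 'rc1'."""
    parts = []
    for chunk in version.split(".")[:3]:
        digits = ""
        for char in chunk:
            if not char.isdigit():
                break
            digits += char
        parts.append(int(digits) if digits else 0)
    # Pad short versions such as '1.27' to three components.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def is_supported_server(server_version: str) -> bool:
    """Return True when the server is at least MIN_SERVER_VERSION."""
    return parse_version(server_version) >= MIN_SERVER_VERSION
```

With this sketch, `is_supported_server("1.26.0")` is false and `is_supported_server("1.27.0")` is true, which matches the failure reported later in this thread.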

Closes argilla-io/argilla#4880

@alvarobartt alvarobartt added this to the 1.2.0 milestone Jun 6, 2024
@alvarobartt alvarobartt self-assigned this Jun 6, 2024
@@ -20,23 +20,24 @@
from pydantic import Field, PrivateAttr, SecretStr

try:
-    import argilla as rg
+    import argilla_sdk as rg
Contributor

@frascuchon What do you think the timeline is for renaming argilla_sdk to argilla? Is it worth waiting for those changes in argilla so that this stays in line?

Member

These changes are minimal. We can leave it as is until we have more stable versions of argilla v2.

Contributor

@burtenshaw burtenshaw left a comment

Nice start. Just a few high-level comments.

src/distilabel/steps/argilla/base.py Outdated
src/distilabel/steps/argilla/text_generation.py Outdated
Member Author

alvarobartt commented Jun 6, 2024

Edit: the issue was the Argilla server version: I was using 1.26.0, while 1.27.0 or higher is required 👍🏻

As an update, @burtenshaw @frascuchon: I've installed argilla_sdk from the latest version of argilla-python in main, and running the code below leads to the following exception, which complains that the /records/bulk endpoint cannot be found. Could you try to reproduce on your end?


Install as pip install "distilabel[hf-inference-endpoints] @ git+https://github.com/argilla-io/[email protected]" and then run the code below with your personal HF_TOKEN and your Argilla credentials (feel free to use dev):

from uuid import uuid4

from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps import (
    LoadDataFromDicts,
    TextGenerationToArgilla,
)
from distilabel.steps.tasks import TextGeneration

if __name__ == "__main__":
    with Pipeline(name="my-pipeline") as pipeline:
        load_dataset = LoadDataFromDicts(
            name="load_dataset",
            data=[
                {
                    "instruction": "Write a short story about a dragon that saves a princess from a tower.",
                },
            ],
        )

        text_generation = TextGeneration(
            name="text_generation",
            llm=InferenceEndpointsLLM(
                model_id="meta-llama/Meta-Llama-3-8B-Instruct",
                tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
                api_key="...",  # type: ignore
            ),
            num_generations=4,
            group_generations=True,
        )

        text_generation_to_argilla = TextGenerationToArgilla(
            name="text_generation_to_argilla",
            api_url="...",
            api_key="...",  # type: ignore
            dataset_name=f"text-generation-{uuid4()}",
            dataset_workspace="admin",
        )

        (  # type: ignore
            load_dataset
            >> text_generation
            >> text_generation_to_argilla
        )

    pipeline.run(
        parameters={
            text_generation.name: {  # type: ignore
                "llm": {
                    "generation_kwargs": {
                        "max_new_tokens": 512,
                        "temperature": 0.7,
                    },
                },
            },
        }
    )

The logs then look like:

[image: screenshot of the logs]

@alvarobartt alvarobartt marked this pull request as ready for review June 7, 2024 11:11
Comment on lines +121 to +131
return (
    True
    if self._client.datasets(  # type: ignore
        name=self.dataset_name,  # type: ignore
        workspace=self._client.workspaces(name=self.dataset_workspace)  # type: ignore
        if self.dataset_workspace is not None
        else None,
    ).exists()
    is not None
    else False
)
Member

@frascuchon frascuchon Jun 7, 2024

dataset.exists() should already return a bool value, so the extra checks may not be needed.

Suggested change
- return (
-     True
-     if self._client.datasets(  # type: ignore
-         name=self.dataset_name,  # type: ignore
-         workspace=self._client.workspaces(name=self.dataset_workspace)  # type: ignore
-         if self.dataset_workspace is not None
-         else None,
-     ).exists()
-     is not None
-     else False
- )
+ workspace = (
+     self._client.workspaces(name=self.dataset_workspace)  # type: ignore
+     if self.dataset_workspace is not None
+     else None
+ )
+ return (
+     self._client.datasets(  # type: ignore
+         name=self.dataset_name,  # type: ignore
+         workspace=workspace,
+     ).exists()
+ )

Member Author

I didn't do it that way since it was returning None. I'm not sure whether that is still the case in the version in the argilla repository.
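
To make the two positions concrete, here is a minimal sketch of a defensive check that gives a plain `bool` whether `exists()` returns a `bool` or, as reported above, `None`. `FakeDataset` and `dataset_exists` are hypothetical stand-ins for illustration only; they are not argilla_sdk classes, and the real return type should be checked against the client version in use.

```python
from typing import Optional


class FakeDataset:
    """Hypothetical stand-in for a dataset handle; not the argilla_sdk class."""

    def __init__(self, exists_result: Optional[bool]) -> None:
        self._exists_result = exists_result

    def exists(self) -> Optional[bool]:
        # Returns whatever the (assumed) client would return: True, False or None.
        return self._exists_result


def dataset_exists(dataset: FakeDataset) -> bool:
    # bool(None) is False, so a None result is treated as "does not exist".
    # By contrast, `dataset.exists() is not None` would report True even
    # when exists() returns False.
    return bool(dataset.exists())
```

Wrapping the call in `bool(...)` keeps the check correct under either behaviour, at the cost of hiding a `None` that may signal a client-side problem.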

],
MockDataset = rg.Dataset(
    name="dataset",
    workspace=rg.Workspace("workspace"),  # type: ignore
Member

@burtenshaw I think this is a side effect of calling the API in Dataset.__init__. Maybe we should review this so that client-dependent actions happen only in methods that call the API and not in init (for example, the workspace is only needed when creating the dataset).
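
The idea in this comment can be sketched as follows: the constructor only stores arguments, and the workspace is resolved when `create()` is called. This is an assumption-laden illustration, not the argilla_sdk implementation; `FakeClient`, `LazyDataset`, `get_workspace` and `create_dataset` are hypothetical names.

```python
from typing import Dict, List


class FakeClient:
    """Hypothetical client that records every API call it receives."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def get_workspace(self, name: str) -> Dict[str, str]:
        self.calls.append(f"get_workspace:{name}")
        return {"name": name}

    def create_dataset(self, name: str, workspace: Dict[str, str]) -> None:
        self.calls.append(f"create_dataset:{name}")


class LazyDataset:
    """Hypothetical dataset whose constructor performs no API calls."""

    def __init__(self, name: str, workspace_name: str, client: FakeClient) -> None:
        # Only store the arguments; nothing is sent to the server here,
        # so the object can be built in tests without a live server.
        self.name = name
        self.workspace_name = workspace_name
        self._client = client

    def create(self) -> None:
        # The workspace is resolved only now, when it is actually needed.
        workspace = self._client.get_workspace(self.workspace_name)
        self._client.create_dataset(self.name, workspace)
```

With this shape, constructing `LazyDataset("dataset", "workspace", client)` leaves `client.calls` empty, and the calls appear only after `create()`.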

@alvarobartt alvarobartt modified the milestones: 1.2.0, 1.3.0 Jun 11, 2024
Base automatically changed from develop to main June 18, 2024 12:36
@gabrielmbmb gabrielmbmb changed the base branch from main to develop June 19, 2024 11:21
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[FEATURE] Upgrade distilabel to use Argilla 2.0
3 participants