# Python Client
The LoRAX Python client provides a convenient way of interfacing with a `lorax` instance running in your environment.
## Install

```bash
pip install lorax-client
```
## Usage

```python
from lorax import Client

endpoint_url = "http://127.0.0.1:8080"
client = Client(endpoint_url)

text = client.generate("Why is the sky blue?", adapter_id="some/adapter").generated_text
print(text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```
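Generation parameters can be passed alongside the prompt. A minimal sketch, assuming the client accepts sampling options such as `max_new_tokens` (used in the batch example below) and `temperature`:

```python
from lorax import Client

client = Client("http://127.0.0.1:8080")

# Sketch: max_new_tokens appears in the batch example below;
# temperature is assumed to be accepted by the client as well.
response = client.generate(
    "Why is the sky blue?",
    adapter_id="some/adapter",
    max_new_tokens=64,
    temperature=0.7,
)
print(response.generated_text)
```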
The same calls are available on the asynchronous client:

```python
from lorax import AsyncClient

endpoint_url = "http://127.0.0.1:8080"
client = AsyncClient(endpoint_url)

response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
print(response.generated_text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```
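The `await` syntax above assumes an already-running event loop (for example, a Jupyter notebook). In a standalone script, you can wrap the calls in a coroutine and drive it with `asyncio.run` — a minimal sketch:

```python
import asyncio

from lorax import AsyncClient

async def main():
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
    print(response.generated_text)

asyncio.run(main())
```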
See the API reference for full details.
## Batch Inference

In some cases you may have a list of prompts that you wish to process in bulk ("batch processing"). Rather than process each prompt one at a time, you can take advantage of the `AsyncClient` and LoRAX's native parallelism to submit your prompts at once and await the results:
```python
import asyncio
import time

from lorax import AsyncClient

# Batch of prompts to submit
prompts = [
    "The quick brown fox",
    "The rain in Spain",
    "What comes up",
]

# Initialize the async client
endpoint_url = "http://127.0.0.1:8080"
async_client = AsyncClient(endpoint_url)

# Submit all prompts and do not block on the response
t0 = time.time()
futures = []
for prompt in prompts:
    resp = async_client.generate(prompt, max_new_tokens=64)
    futures.append(resp)

# Await the completion of all the prompt requests
responses = await asyncio.gather(*futures)

# Print responses
# Responses will always come back in the same order as the original list
for resp in responses:
    print(resp.generated_text)

# Print duration to process all requests in batch
print("duration (s):", time.time() - t0)
```
Output:

```
duration (s): 2.9093329906463623
```
Compare this against the duration of submitting one prompt at a time. You should find that for 3 prompts, the async approach is about 2.5 - 3x faster than serial processing:
```python
from lorax import Client

client = Client(endpoint_url)

t0 = time.time()
responses = []
for prompt in prompts:
    resp = client.generate(prompt, max_new_tokens=64)
    responses.append(resp)

for resp in responses:
    print(resp.generated_text)

print("duration (s):", time.time() - t0)
```
Output:

```
duration (s): 8.385080099105835
```
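For larger batches, you may want to cap how many requests are in flight at once. This is plain `asyncio` rather than anything LoRAX-specific; a sketch using `asyncio.Semaphore` with an illustrative limit of 8 concurrent requests:

```python
import asyncio

from lorax import AsyncClient

async_client = AsyncClient("http://127.0.0.1:8080")
semaphore = asyncio.Semaphore(8)  # illustrative concurrency limit

async def generate_limited(prompt: str):
    # Acquire the semaphore before issuing the request so that at most
    # 8 requests are outstanding at any time.
    async with semaphore:
        return await async_client.generate(prompt, max_new_tokens=64)

async def run_batch(prompts):
    # Results come back in the same order as the input prompts.
    return await asyncio.gather(*(generate_limited(p) for p in prompts))

responses = asyncio.run(run_batch(["The quick brown fox", "The rain in Spain"]))
for resp in responses:
    print(resp.generated_text)
```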
## Predibase Inference Endpoints

The LoRAX client can also be used to connect to Predibase managed LoRAX endpoints (including Predibase's serverless endpoints).

You need only make the following changes to the above examples:

- Change the `endpoint_url` to match the endpoint of your Predibase LLM of choice.
- Provide your Predibase API token in the `headers` provided to the client.
Example:

```python
from lorax import Client

# You can get your Predibase API token by going to Settings > My Profile > Generate API Token
# You can get your Predibase Tenant short code by going to Settings > My Profile > Overview > Tenant ID
endpoint_url = f"https://serving.app.predibase.com/{predibase_tenant_short_code}/deployments/v2/llms/{llm_deployment_name}"
headers = {
    "Authorization": f"Bearer {api_token}"
}
client = Client(endpoint_url, headers=headers)

# same as above from here ...
response = client.generate("Why is the sky blue?", adapter_id=f"{model_repo}/{model_version}")
```
Note that by default, Predibase will use its internal model repos as the default `adapter_source`. To use an adapter from the Hugging Face Hub:

```python
response = client.generate("Why is the sky blue?", adapter_id="some/adapter", adapter_source="hub")
```
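The same pattern works with the asynchronous client if you want to combine a Predibase endpoint with the batch approach above. A sketch, assuming `AsyncClient` accepts a `headers` argument in the same way `Client` does:

```python
from lorax import AsyncClient

# Assumption: AsyncClient takes headers the same way Client does.
async_client = AsyncClient(endpoint_url, headers=headers)
response = await async_client.generate("Why is the sky blue?", adapter_id="some/adapter", adapter_source="hub")
print(response.generated_text)
```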