Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,16 @@ pip install -r requirements.txt
pip install -r questions/inference_server/model-requirements.txt
pip install -r dev-requirements.txt
pip install -r requirements-test.txt
pip install vllm # optional, enables faster inference
```

When installing `vllm` for the SmolLM models used by this project, prefer the
prebuilt wheels with precompiled GPU kernels, as they give the best performance.
Refer to the [vLLM installation guide](https://github.com/vllm-project/vllm#installation)
for the latest GPU-optimized wheels.

Set `USE_VLLM=0` to force the server to skip vLLM even when installed.

Using CUDA is important for fast inference.

```shell
Expand Down
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ pytest-asyncio
requests-futures
gradio
black
vllm
8 changes: 7 additions & 1 deletion questions/inference_server/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
fast_inference,
fast_feature_extract_inference,
)
from questions.vllm_inference import VLLM_AVAILABLE, vllm_inference
USE_VLLM = os.getenv("USE_VLLM", "1") == "1"
from questions.utils import log_time
from sellerinfo import session_secret
from .models import build_model
Expand Down Expand Up @@ -967,7 +969,10 @@ async def generate_route(
# status_code=401, detail="Please subscribe at https://text-generator.io/subscribe first"
# )
# todo validate api key and user
inference_result = fast_inference(generate_params, MODEL_CACHE)
if VLLM_AVAILABLE and USE_VLLM:
inference_result = vllm_inference(generate_params, weights_path_tgz)
else:
inference_result = fast_inference(generate_params, MODEL_CACHE)
# todo vuln
if request and background_tasks:
if (
Expand Down Expand Up @@ -1347,3 +1352,4 @@ def tts_demo(request: Request):
# return HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")

if __name__ == "__main__":
pass
75 changes: 75 additions & 0 deletions questions/vllm_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import math
from typing import List

# Optional dependency: vLLM provides much faster batched GPU inference.  The
# import is wrapped so the rest of the package keeps working when vllm (or one
# of its CUDA dependencies) is not installed; callers gate on VLLM_AVAILABLE
# before touching any vLLM symbol.
try:
    from vllm import LLM, SamplingParams
    # NOTE(review): logits_to_probs is imported but never used in this module —
    # confirm it exists in the pinned vllm version, or drop the import.
    from vllm.utils import logits_to_probs

    VLLM_AVAILABLE = True
except Exception:
    from typing import Any

    # Fallback aliases keep type references importable without vllm installed.
    LLM = Any  # type: ignore
    SamplingParams = Any  # type: ignore
    logits_to_probs = None  # type: ignore
    VLLM_AVAILABLE = False

from nltk.tokenize import sent_tokenize

from questions.models import GenerateParams
from questions.fixtures import set_stop_reason

# Process-wide singleton engine; populated lazily by load_vllm().
VLLM_MODEL = None


def load_vllm(model_path: str):
    """Return the process-wide vLLM engine, creating it on first use.

    The engine is cached in the module-level ``VLLM_MODEL`` singleton, so
    later calls return the cached engine and ignore *model_path*.

    Args:
        model_path: model directory or HF model id passed to ``LLM``.

    Raises:
        RuntimeError: if the vllm package is not installed.
    """
    global VLLM_MODEL
    if VLLM_MODEL is not None:
        return VLLM_MODEL
    if not VLLM_AVAILABLE:
        raise RuntimeError("vLLM is not installed")
    VLLM_MODEL = LLM(model=model_path, dtype="auto")
    return VLLM_MODEL


def _token_logprob(entry):
    """Best-effort extraction of the sampled token's log-probability.

    vLLM's ``CompletionOutput.logprobs`` holds one entry per generated token;
    depending on the vLLM version an entry is either a bare float or a
    ``{token_id: Logprob}`` dict (with ``logprobs=1`` the dict contains only
    the sampled token).  Returns ``None`` when no numeric value can be found.
    """
    if isinstance(entry, (int, float)):
        return float(entry)
    if isinstance(entry, dict) and entry:
        first = next(iter(entry.values()))
        if isinstance(first, (int, float)):
            return float(first)
        value = getattr(first, "logprob", None)  # vllm Logprob object
        if value is not None:
            return float(value)
    return None


def vllm_inference(generate_params: GenerateParams, model_path: str) -> List[dict]:
    """Generate text with vLLM, mirroring ``fast_inference``'s result shape.

    Args:
        generate_params: decoding settings — prompt text, sampling knobs, and
            the ``min_probability`` / ``max_sentences`` post-filters.
        model_path: model directory or HF id, handed to the engine on first use.

    Returns:
        One ``{"generated_text": str, "stop_reason": str}`` dict per sequence.

    Raises:
        RuntimeError: if the vllm package is not installed.
    """
    if not VLLM_AVAILABLE:
        raise RuntimeError("vLLM is not installed")
    llm = load_vllm(model_path)

    sampling_params = SamplingParams(
        n=generate_params.number_of_results,
        temperature=generate_params.temperature,
        top_p=generate_params.top_p,
        top_k=generate_params.top_k,
        max_tokens=generate_params.max_length,
        stop=generate_params.stop_sequences,
        logprobs=1,  # per-token logprobs drive the min_probability cutoff
    )

    results = []
    outputs = llm.generate([generate_params.text], sampling_params)
    output = outputs[0]
    for seq in output.outputs:
        text = seq.text
        logprobs = seq.logprobs or []
        stop_reason = "max_length"

        if generate_params.min_probability:
            cumulative = 1.0
            cut_idx = None
            for i, lp in enumerate(logprobs):
                # BUG FIX: current vLLM yields {token_id: Logprob} dicts here,
                # not floats — math.exp(lp) on a dict raised TypeError.
                lp_value = _token_logprob(lp)
                if lp_value is None:
                    continue
                cumulative *= math.exp(lp_value)
                if cumulative < generate_params.min_probability:
                    cut_idx = i
                    stop_reason = "min_probability"
                    break
            if cut_idx is not None:
                # NOTE(review): cut_idx is a *token* index applied to
                # whitespace-split words, so this truncation is approximate.
                words = text.split()
                text = " ".join(words[: cut_idx + 1])

        if generate_params.max_sentences:
            sentences = sent_tokenize(text)
            if len(sentences) > generate_params.max_sentences:
                text = " ".join(sentences[: generate_params.max_sentences])
                stop_reason = "max_sentences"

        results.append({"generated_text": text, "stop_reason": stop_reason})
    return results
1 change: 1 addition & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ numpy
ruff==0.11.10
httpx
colorama
vllm
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ httpcore
pillow
pyppeteer
markitdown[all]==0.1.2
vllm
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,4 @@ youtube-transcript-api==1.0.3
# via markitdown
zipp==3.22.0
# via importlib-metadata
vllm
23 changes: 23 additions & 0 deletions scripts/run_vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/usr/bin/env python
"""Simple CLI for running vLLM inference."""
import argparse
from questions.vllm_inference import vllm_inference, VLLM_AVAILABLE
from questions.models import GenerateParams
from questions.constants import weights_path_tgz


def main() -> None:
    """CLI entry point: read a prompt from argv and print one vLLM generation."""
    arg_parser = argparse.ArgumentParser(description="Generate text using vLLM")
    arg_parser.add_argument("text", help="prompt text")
    arg_parser.add_argument("--max-length", type=int, default=100)
    cli_args = arg_parser.parse_args()

    request = GenerateParams(text=cli_args.text, max_length=cli_args.max_length)
    if not VLLM_AVAILABLE:
        raise SystemExit("vLLM is not installed")
    first_result = vllm_inference(request, weights_path_tgz)[0]
    print(first_result["generated_text"])


if __name__ == "__main__":
    main()
16 changes: 16 additions & 0 deletions tests/unit/test_vllm_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import os
import importlib
import pytest

from questions.vllm_inference import VLLM_AVAILABLE


def test_use_vllm_env_flag(monkeypatch):
    """With USE_VLLM=0 in the environment the reloaded server module must
    disable vLLM, regardless of whether the vllm package is importable."""
    monkeypatch.setenv("USE_VLLM", "0")
    try:
        module = importlib.reload(
            importlib.import_module("questions.inference_server.inference_server")
        )
    except ModuleNotFoundError:
        pytest.skip("torch or other deps missing")
    # USE_VLLM is read from the environment at import time, so after the
    # reload it must be False.  The previous `or (not VLLM_AVAILABLE)`
    # disjunct made the assertion vacuously true whenever vllm was missing,
    # so the env flag was never actually verified.
    assert module.USE_VLLM is False
15 changes: 15 additions & 0 deletions tests/unit/test_vllm_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
from questions.vllm_inference import VLLM_AVAILABLE, vllm_inference
from questions.models import GenerateParams


def test_vllm_import_flag():
    """The availability flag must always be a real boolean, never None/truthy."""
    assert type(VLLM_AVAILABLE) is bool


def test_vllm_inference_raises_when_missing():
    """Without the vllm package installed, vllm_inference must refuse to run."""
    if VLLM_AVAILABLE:
        pytest.skip("vLLM installed, skip missing test")
    request = GenerateParams(text="hi")
    with pytest.raises(RuntimeError):
        vllm_inference(request, "models/SmolLM-1.7B")
Loading