diff --git a/README.md b/README.md
index 5f7edde..563e7bc 100755
--- a/README.md
+++ b/README.md
@@ -76,8 +76,16 @@
 pip install -r requirements.txt
 pip install -r questions/inference_server/model-requirements.txt
 pip install -r dev-requirements.txt
 pip install -r requirements-test.txt
+pip install vllm # optional, enables faster inference
 ```
+When installing `vllm` for the Smol models we use, prefer prebuilt kernels
+when available to get the best performance. Refer to the
+[vLLM installation guide](https://github.com/vllm-project/vllm#installation)
+for the latest GPU-optimized wheels.
+
+Set `USE_VLLM=0` to force the server to skip vLLM even when installed.
+
 Using cuda is important to speed up inference.
 
 ```shell
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 1ca915f..6bf53e9 100755
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -5,3 +5,4 @@ pytest-asyncio
 requests-futures
 gradio
 black
+vllm
diff --git a/questions/inference_server/inference_server.py b/questions/inference_server/inference_server.py
index 88278a8..7c24ce4 100644
--- a/questions/inference_server/inference_server.py
+++ b/questions/inference_server/inference_server.py
@@ -54,6 +54,8 @@
     fast_inference,
     fast_feature_extract_inference,
 )
+from questions.vllm_inference import VLLM_AVAILABLE, vllm_inference
+USE_VLLM = os.getenv("USE_VLLM", "1") == "1"
 from questions.utils import log_time
 from sellerinfo import session_secret
 from .models import build_model
@@ -967,7 +969,10 @@ async def generate_route(
     # status_code=401, detail="Please subscribe at https://text-generator.io/subscribe first"
     # )
     # todo validate api key and user
-    inference_result = fast_inference(generate_params, MODEL_CACHE)
+    if VLLM_AVAILABLE and USE_VLLM:
+        inference_result = vllm_inference(generate_params, weights_path_tgz)
+    else:
+        inference_result = fast_inference(generate_params, MODEL_CACHE)
     # todo vuln
     if request and background_tasks:
         if (
@@ -1347,3 +1352,4 @@ def tts_demo(request: Request):
     # return HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
 
 if __name__ == "__main__":
+    pass
diff --git a/questions/vllm_inference.py b/questions/vllm_inference.py
new file mode 100644
index 0000000..b188bdc
--- /dev/null
+++ b/questions/vllm_inference.py
@@ -0,0 +1,75 @@
+import math
+from typing import List
+
+try:
+    from vllm import LLM, SamplingParams
+    from vllm.utils import logits_to_probs
+    VLLM_AVAILABLE = True
+except Exception:
+    from typing import Any
+    LLM = Any  # type: ignore
+    SamplingParams = Any  # type: ignore
+    logits_to_probs = None  # type: ignore
+    VLLM_AVAILABLE = False
+
+from nltk.tokenize import sent_tokenize
+
+from questions.models import GenerateParams
+from questions.fixtures import set_stop_reason
+
+VLLM_MODEL = None
+
+
+def load_vllm(model_path: str):
+    global VLLM_MODEL
+    if VLLM_MODEL is None:
+        if not VLLM_AVAILABLE:
+            raise RuntimeError("vLLM is not installed")
+        VLLM_MODEL = LLM(model=model_path, dtype="auto")
+    return VLLM_MODEL
+
+
+def vllm_inference(generate_params: GenerateParams, model_path: str) -> List[dict]:
+    if not VLLM_AVAILABLE:
+        raise RuntimeError("vLLM is not installed")
+    llm = load_vllm(model_path)
+
+    sampling_params = SamplingParams(
+        n=generate_params.number_of_results,
+        temperature=generate_params.temperature,
+        top_p=generate_params.top_p,
+        top_k=generate_params.top_k,
+        max_tokens=generate_params.max_length,
+        stop=generate_params.stop_sequences,
+        logprobs=1,
+    )
+
+    results = []
+    outputs = llm.generate([generate_params.text], sampling_params)
+    output = outputs[0]
+    for seq in output.outputs:
+        text = seq.text
+        logprobs = seq.logprobs or []
+        stop_reason = "max_length"
+
+        if generate_params.min_probability:
+            cumulative = 1.0
+            cut_idx = None
+            for i, lp in enumerate(logprobs):
+                cumulative *= math.exp(lp)
+                if cumulative < generate_params.min_probability:
+                    cut_idx = i
+                    stop_reason = "min_probability"
+                    break
+            if cut_idx is not None:
+                words = text.split()
+                text = " ".join(words[: cut_idx + 1])
+
+        if generate_params.max_sentences:
+            sentences = sent_tokenize(text)
+            if len(sentences) > generate_params.max_sentences:
+                text = " ".join(sentences[: generate_params.max_sentences])
+                stop_reason = "max_sentences"
+
+        results.append({"generated_text": text, "stop_reason": stop_reason})
+    return results
diff --git a/requirements-test.txt b/requirements-test.txt
index ee8652b..f0957b3 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -4,3 +4,4 @@ numpy
 ruff==0.11.10
 httpx
 colorama
+vllm
diff --git a/requirements.in b/requirements.in
index a996c45..343152a 100755
--- a/requirements.in
+++ b/requirements.in
@@ -51,3 +51,4 @@ httpcore
 pillow
 pyppeteer
 markitdown[all]==0.1.2
+vllm
diff --git a/requirements.txt b/requirements.txt
index 7f19911..596f13b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -339,3 +339,4 @@ youtube-transcript-api==1.0.3
     # via markitdown
 zipp==3.22.0
     # via importlib-metadata
+vllm
diff --git a/scripts/run_vllm.py b/scripts/run_vllm.py
new file mode 100755
index 0000000..d2fe8eb
--- /dev/null
+++ b/scripts/run_vllm.py
@@ -0,0 +1,23 @@
+#!/usr/bin/env python
+"""Simple CLI for running vLLM inference."""
+import argparse
+from questions.vllm_inference import vllm_inference, VLLM_AVAILABLE
+from questions.models import GenerateParams
+from questions.constants import weights_path_tgz
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Generate text using vLLM")
+    parser.add_argument("text", help="prompt text")
+    parser.add_argument("--max-length", type=int, default=100)
+    args = parser.parse_args()
+
+    params = GenerateParams(text=args.text, max_length=args.max_length)
+    if not VLLM_AVAILABLE:
+        raise SystemExit("vLLM is not installed")
+    result = vllm_inference(params, weights_path_tgz)[0]
+    print(result["generated_text"])
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/unit/test_vllm_env.py b/tests/unit/test_vllm_env.py
new file mode 100644
index 0000000..2679ec4
--- /dev/null
+++ b/tests/unit/test_vllm_env.py
@@ -0,0 +1,16 @@
+import os
+import importlib
+import pytest
+
+from questions.vllm_inference import VLLM_AVAILABLE
+
+
+def test_use_vllm_env_flag(monkeypatch):
+    monkeypatch.setenv("USE_VLLM", "0")
+    try:
+        module = importlib.reload(
+            importlib.import_module("questions.inference_server.inference_server")
+        )
+    except ModuleNotFoundError:
+        pytest.skip("torch or other deps missing")
+    assert (module.USE_VLLM is False) or (not VLLM_AVAILABLE)
diff --git a/tests/unit/test_vllm_inference.py b/tests/unit/test_vllm_inference.py
new file mode 100644
index 0000000..f1a7e77
--- /dev/null
+++ b/tests/unit/test_vllm_inference.py
@@ -0,0 +1,15 @@
+import pytest
+from questions.vllm_inference import VLLM_AVAILABLE, vllm_inference
+from questions.models import GenerateParams
+
+
+def test_vllm_import_flag():
+    assert isinstance(VLLM_AVAILABLE, bool)
+
+
+def test_vllm_inference_raises_when_missing():
+    if VLLM_AVAILABLE:
+        pytest.skip("vLLM installed, skip missing test")
+    params = GenerateParams(text="hi")
+    with pytest.raises(RuntimeError):
+        vllm_inference(params, "models/SmolLM-1.7B")
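
For reference, a minimal sketch (not part of the patch) of how the USE_VLLM opt-out and the new module fit together outside the server. It assumes vLLM is installed and reuses "models/SmolLM-1.7B", the local model path that appears in the unit tests; parameter names follow GenerateParams as used above.

import os

from questions.models import GenerateParams
from questions.vllm_inference import VLLM_AVAILABLE, vllm_inference

# Same opt-out knob the server reads: USE_VLLM=0 disables the vLLM path.
USE_VLLM = os.getenv("USE_VLLM", "1") == "1"

params = GenerateParams(text="The quick brown fox", max_length=64)

if VLLM_AVAILABLE and USE_VLLM:
    # Each result dict carries the generated text and the stop reason.
    for result in vllm_inference(params, "models/SmolLM-1.7B"):
        print(result["stop_reason"], result["generated_text"])
else:
    # Mirrors the server behaviour: without vLLM it falls back to fast_inference().
    print("vLLM disabled or not installed")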