-
Notifications
You must be signed in to change notification settings - Fork 49
Expand file tree
/
Copy pathllm_processor.py
More file actions
95 lines (85 loc) · 3.85 KB
/
llm_processor.py
File metadata and controls
95 lines (85 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from abc import ABC, abstractmethod
import google.generativeai as genai
from openai import OpenAI, AsyncOpenAI
from typing import AsyncGenerator, Generator, Optional
import logging
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
# NOTE(review): forcing INFO on a module logger overrides whatever level the
# application configures for this module — confirm this is intentional; library
# code usually leaves the level to the root/application config.
logger.setLevel(logging.INFO)
class LLMProcessor(ABC):
    """Abstract interface for LLM text-processing backends.

    Concrete subclasses implement two variants of the same operation:
    an async streaming form and a blocking form returning the full text.
    """

    @abstractmethod
    async def process_text(self, text: str, prompt: str, model: Optional[str] = None) -> AsyncGenerator[str, None]:
        """Asynchronously yield response chunks for ``prompt`` applied to ``text``.

        ``model`` overrides the backend's default model when given.
        """
        ...

    @abstractmethod
    def process_text_sync(self, text: str, prompt: str, model: Optional[str] = None) -> str:
        """Return the complete response for ``prompt`` applied to ``text``.

        ``model`` overrides the backend's default model when given.
        """
        ...
class GeminiProcessor(LLMProcessor):
    """LLM processor backed by Google's Generative AI (Gemini) SDK.

    Reads the API key from the GOOGLE_API_KEY environment variable at
    construction time and configures the global ``genai`` client with it.
    """

    def __init__(self, default_model: str = 'gemini-1.5-pro'):
        """Configure the genai SDK.

        Raises:
            EnvironmentError: if GOOGLE_API_KEY is not set.
        """
        key = os.getenv("GOOGLE_API_KEY")
        if not key:
            raise EnvironmentError("GOOGLE_API_KEY is not set")
        genai.configure(api_key=key)
        self.default_model = default_model

    async def process_text(self, text: str, prompt: str, model: Optional[str] = None) -> AsyncGenerator[str, None]:
        """Stream Gemini response chunks for ``prompt`` followed by ``text``."""
        combined = f"{prompt}\n\n{text}"
        chosen = model or self.default_model
        logger.info(f"Using model: {chosen} for processing")
        logger.debug(f"Prompt (length={len(combined)} chars)")
        stream = await genai.GenerativeModel(chosen).generate_content_async(
            combined, stream=True
        )
        async for piece in stream:
            # Skip chunks with no text payload (e.g. metadata-only chunks).
            if piece.text:
                yield piece.text

    def process_text_sync(self, text: str, prompt: str, model: Optional[str] = None) -> str:
        """Return the full Gemini response for ``prompt`` followed by ``text``."""
        combined = f"{prompt}\n\n{text}"
        chosen = model or self.default_model
        logger.info(f"Using model: {chosen} for sync processing")
        logger.debug(f"Prompt (length={len(combined)} chars)")
        return genai.GenerativeModel(chosen).generate_content(combined).text
class GPTProcessor(LLMProcessor):
    """LLM processor backed by the OpenAI chat-completions API.

    Reads the API key from the OPENAI_API_KEY environment variable and holds
    both an async client (for streaming) and a sync client.
    """

    def __init__(self):
        """Build the OpenAI clients.

        Raises:
            ValueError: if OPENAI_API_KEY is not set.  (Kept as ValueError for
            backward compatibility, though GeminiProcessor raises
            EnvironmentError for the same condition.)
        """
        # Read the key once instead of three separate os.getenv calls.
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OpenAI API key not found in environment variables")
        self.async_client = AsyncOpenAI(api_key=api_key)
        self.sync_client = OpenAI(api_key=api_key)
        self.default_model = "gpt-4"

    async def process_text(self, text: str, prompt: str, model: Optional[str] = None) -> AsyncGenerator[str, None]:
        """Stream response chunks for ``prompt`` followed by ``text``.

        ``model`` overrides ``self.default_model`` when given.
        """
        all_prompt = f"{prompt}\n\n{text}"
        model_name = model or self.default_model
        logger.info(f"Using model: {model_name} for processing")
        logger.debug(f"Prompt (length={len(all_prompt)} chars)")
        response = await self.async_client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "user", "content": all_prompt}
            ],
            stream=True
        )
        async for chunk in response:
            # Streaming deltas may be empty (role-only or final chunks); skip them.
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

    def process_text_sync(self, text: str, prompt: str, model: Optional[str] = None) -> str:
        """Return the full response for ``prompt`` followed by ``text``.

        ``model`` overrides ``self.default_model`` when given.
        """
        all_prompt = f"{prompt}\n\n{text}"
        model_name = model or self.default_model
        logger.info(f"Using model: {model_name} for sync processing")
        logger.debug(f"Prompt (length={len(all_prompt)} chars)")
        response = self.sync_client.chat.completions.create(
            model=model_name,
            messages=[
                {"role": "user", "content": all_prompt}
            ]
        )
        return response.choices[0].message.content
def get_llm_processor(model: str) -> LLMProcessor:
    """Return the processor matching a model name (case-insensitive).

    Names starting with ``gemini`` map to GeminiProcessor (with the lowercased
    name as its default model); names starting with ``gpt-`` or ``o1-`` map to
    GPTProcessor.

    Raises:
        ValueError: for any other model name.
    """
    model = model.lower()
    # 'gemini' alone covers every 'gemini-*' name; the original
    # ('gemini', 'gemini-') tuple had a dead second element.
    if model.startswith('gemini'):
        return GeminiProcessor(default_model=model)
    elif model.startswith(('gpt-', 'o1-')):
        return GPTProcessor()
    else:
        raise ValueError(f"Unsupported model type: {model}")