# Thin wrappers around the OpenAI completion and chat APIs.
# Note: this module targets the legacy `openai` SDK (< 1.0), which exposes
# `openai.Completion` and `openai.ChatCompletion`.
from typing import List, Union, Optional, Literal
import dataclasses

import openai
from tenacity import (
    retry,
    stop_after_attempt,  # type: ignore
    wait_random_exponential,  # type: ignore
)

MessageRole = Literal["system", "user", "assistant"]


# Must be a dataclass: gpt_chat serializes messages with dataclasses.asdict.
@dataclasses.dataclass
class Message:
    role: MessageRole
    content: str


def message_to_str(message: Message) -> str:
    return f"{message.role}: {message.content}"


def messages_to_str(messages: List[Message]) -> str:
    return "\n".join([message_to_str(message) for message in messages])
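
# Illustrative only: how a short conversation renders as a flat transcript.
# >>> messages_to_str([Message("system", "Be terse."), Message("user", "Hi")])
# 'system: Be terse.\nuser: Hi'
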
# Exponential backoff on transient API errors (rate limits, timeouts); the
# bounds here are illustrative. Without a @retry decorator the tenacity
# imports above would go unused.
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_completion(
    model: str,
    prompt: str,
    max_tokens: int = 1024,
    stop_strs: Optional[List[str]] = None,
    temperature: float = 0.0,
    num_comps: int = 1,
) -> Union[List[str], str]:
    """Call the legacy text-completion endpoint, retrying transient failures."""
    response = openai.Completion.create(
        model=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop_strs,
        n=num_comps,
    )
    # With a single completion, return the bare string; otherwise a list.
    if num_comps == 1:
        return response.choices[0].text  # type: ignore
    return [choice.text for choice in response.choices]  # type: ignore
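
# Sketch of a direct call (requires OPENAI_API_KEY in the environment;
# "text-davinci-003" is an example model name, not prescribed by this module):
# text = gpt_completion("text-davinci-003", "Q: 2 + 2 = ?\nA:", max_tokens=8)
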
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_chat(
    model: str,
    messages: List[Message],
    max_tokens: int = 1024,
    temperature: float = 0.0,
    num_comps: int = 1,
) -> Union[List[str], str]:
    """Call the chat-completion endpoint, retrying transient failures."""
    try:
        response = openai.ChatCompletion.create(
            model=model,
            # The API expects plain dicts, so serialize each Message dataclass.
            messages=[dataclasses.asdict(message) for message in messages],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            n=num_comps,
        )
        if num_comps == 1:
            return response.choices[0].message.content  # type: ignore
        return [choice.message.content for choice in response.choices]  # type: ignore
    except Exception as e:
        print(f"An error occurred while calling OpenAI: {e}")
        raise

class ModelBase:
    """Common interface: chat models implement generate_chat, completion models generate."""

    def __init__(self, name: str):
        self.name = name
        self.is_chat = False

    def __repr__(self) -> str:
        return f"{self.name}"

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        raise NotImplementedError

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps: int = 1) -> Union[List[str], str]:
        raise NotImplementedError
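
# Sketch of how a caller might dispatch on is_chat (a hypothetical helper,
# not part of the original module):
def run(model: ModelBase, prompt: str) -> Union[List[str], str]:
    if model.is_chat:
        # Chat models take a list of Message objects.
        return model.generate_chat([Message(role="user", content=prompt)])
    # Completion models take a raw prompt string.
    return model.generate(prompt)
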
class GPTChat(ModelBase):
    def __init__(self, model_name: str):
        super().__init__(model_name)
        self.is_chat = True

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        return gpt_chat(self.name, messages, max_tokens, temperature, num_comps)


class GPT4(GPTChat):
    def __init__(self):
        super().__init__("gpt-4")


class GPT35(GPTChat):
    def __init__(self):
        super().__init__("gpt-3.5-turbo")
class GPTDavinci(ModelBase):
    def __init__(self, model_name: str):
        super().__init__(model_name)  # leaves is_chat = False

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps: int = 1) -> Union[List[str], str]:
        return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
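
# Minimal usage sketch (assumes a valid OPENAI_API_KEY in the environment and
# that your account can access the referenced model; not part of the module's API):
if __name__ == "__main__":
    model = GPT35()
    reply = model.generate_chat(
        messages=[
            Message(role="system", content="You are a concise assistant."),
            Message(role="user", content="Name one prime number."),
        ],
        max_tokens=16,
    )
    print(reply)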