Spaces:
Running
Running
| import os | |
| from contextlib import contextmanager, redirect_stderr, redirect_stdout | |
| from typing import List | |
| from tenacity import retry, stop_after_attempt, wait_random_exponential | |
| from ..message import SYSTEM_NAME as SYSTEM | |
| from ..message import Message | |
| from .base import IntelligenceBackend | |
@contextmanager
def suppress_stdout_stderr():
    """A context manager that redirects stdout and stderr to devnull.

    Yields:
        (err, out): the two redirect context objects from contextlib
            (rarely needed by callers; the manager is typically used
            purely for its silencing side effect).

    Note: without the @contextmanager decorator this function would
    return a plain generator, and `with suppress_stdout_stderr():`
    would raise TypeError — the decorator is required.
    """
    with open(os.devnull, "w") as fnull:
        with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
            yield (err, out)
# Import transformers with stdout/stderr silenced (the import can emit
# noisy warnings — presumably the motivation here; confirm if silencing
# is still desired).
with suppress_stdout_stderr():
    # Try to import the transformers package
    try:
        import transformers
        from transformers import pipeline
        from transformers.pipelines.conversational import (
            Conversation,
            ConversationalPipeline,
        )
    except ImportError:
        # Package missing: record the fact instead of failing at import
        # time; the backend's __init__ asserts on this flag later.
        is_transformers_available = False
    else:
        is_transformers_available = True
class TransformersConversational(IntelligenceBackend):
    """Interface to the Transformers ConversationalPipeline.

    Flattens a multi-agent chat history into the alternating
    user-input / model-response structure that the pipeline's
    Conversation object expects: everything not said by `agent_name`
    (system prompts and other agents) is folded into "user" turns,
    and the agent's own messages become "generated responses".
    """

    stateful = False
    type_name = "transformers:conversational"

    def __init__(self, model: str, device: int = -1, **kwargs):
        """Create the conversational pipeline.

        Args:
            model: model name or path passed to transformers.pipeline.
            device: device index (-1 = CPU, >=0 = CUDA device id).
        """
        super().__init__(model=model, device=device, **kwargs)
        self.model = model
        self.device = device
        assert is_transformers_available, "Transformers package is not installed"
        self.chatbot = pipeline(
            task="conversational", model=self.model, device=self.device
        )

    def _get_response(self, conversation):
        # The pipeline returns the same Conversation object updated in
        # place with a new generated response appended; surface only the
        # latest response text.
        conversation = self.chatbot(conversation)
        response = conversation.generated_responses[-1]
        return response

    @staticmethod
    def _msg_template(agent_name, content):
        """Render one message as "[agent_name]: content".

        Must be a @staticmethod: call sites invoke it as
        self._msg_template(name, content) with two arguments, so a plain
        instance method (no `self` parameter) would raise TypeError.
        """
        return f"[{agent_name}]: {content}"

    def query(
        self,
        agent_name: str,
        role_desc: str,
        history_messages: List[Message],
        global_prompt: str = None,
        request_msg: Message = None,
        *args,
        **kwargs,
    ) -> str:
        """Generate this agent's next utterance from the chat history.

        Args:
            agent_name: the speaking agent; its past messages become the
                model's "generated responses".
            role_desc: system-level role description, always prepended.
            history_messages: prior conversation messages.
            global_prompt: optional environment-wide prompt, prepended
                before role_desc when present.
            request_msg: optional final system message (e.g. an explicit
                request to act), appended as the newest user turn.

        Returns:
            The model's response text.
        """
        user_inputs, generated_responses = [], []
        all_messages = (
            [(SYSTEM, global_prompt), (SYSTEM, role_desc)]
            if global_prompt
            else [(SYSTEM, role_desc)]
        )

        for msg in history_messages:
            all_messages.append((msg.agent_name, msg.content))
        if request_msg:
            all_messages.append((SYSTEM, request_msg.content))

        # Fold the mixed-speaker history into strictly alternating
        # user-input / response turns: consecutive non-agent messages are
        # joined into one user turn, consecutive agent messages into one
        # response turn.
        prev_is_user = False  # Whether the previous message is from the user
        for i, message in enumerate(all_messages):
            if i == 0:
                assert (
                    message[0] == SYSTEM
                )  # The first message should be from the system

            if message[0] != agent_name:
                if not prev_is_user:
                    user_inputs.append(self._msg_template(message[0], message[1]))
                else:
                    user_inputs[-1] += "\n" + self._msg_template(message[0], message[1])
                prev_is_user = True
            else:
                # NOTE(review): if the very first history message were from
                # the agent itself, generated_responses would be empty here
                # and the += branch would IndexError — assumes histories
                # always start with a non-agent message; confirm upstream.
                if prev_is_user:
                    generated_responses.append(message[1])
                else:
                    generated_responses[-1] += "\n" + message[1]
                prev_is_user = False

        # A well-formed history ends with a pending user turn for the
        # model to answer.
        assert len(user_inputs) == len(generated_responses) + 1
        past_user_inputs = user_inputs[:-1]
        new_user_input = user_inputs[-1]

        # Recreate a conversation object from the history messages
        conversation = Conversation(
            text=new_user_input,
            past_user_inputs=past_user_inputs,
            generated_responses=generated_responses,
        )

        # Get the response
        response = self._get_response(conversation)
        return response
# Example usage of the underlying Conversation API (kept for reference):
#
# conversation = Conversation("Going to the movies tonight - any suggestions?")
#
# # Steps usually performed by the model when generating a response:
# # 1. Mark the user input as processed (moved to the history)
# conversation.mark_processed()
# # 2. Append a model response
# conversation.append_response("The Big lebowski.")
#
# conversation.add_user_input("Is it good?")