Hugging Face Spaces page header (scraped): the Space's build status was reported as "Runtime error".
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria

# Load the Stable Code 3B checkpoint once at startup. trust_remote_code is
# required because the repo ships its own modeling code; torch_dtype="auto"
# lets transformers pick the checkpoint's native precision.
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stable-code-3b", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stable-code-3b",
    trust_remote_code=True,
    torch_dtype="auto",
)
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last emitted token is a stop id."""

    # Token ids that end generation (presumably the model's end-of-text
    # ids — TODO confirm against the tokenizer's special-token config).
    STOP_IDS = (0, 2)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the most recently generated token matters.
        return int(input_ids[0][-1]) in self.STOP_IDS
def chat(message, history):
    """Run one chat turn against stable-code-3b.

    Parameters
    ----------
    message : str
        The user's prompt text from the Gradio textbox.
    history : list[tuple[str, str]] | None
        Prior (message, response) pairs held in Gradio state; None on the
        first call.

    Returns
    -------
    tuple
        (history, history) — the updated pair list, once for the chatbot
        widget and once for the state output.
    """
    # Local import: StoppingCriteriaList is only needed inside this function
    # and the file's top-level import line does not bring it in.
    from transformers import StoppingCriteriaList

    history = history or []
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    tokens = model.generate(
        **inputs,
        max_new_tokens=4096,
        temperature=0.2,
        do_sample=True,
        # BUG FIX: the original built a StopOnTokens instance but never
        # passed it to generate(), so the stop ids were silently ignored.
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # BUG FIX: decode only the newly generated tokens. The original decoded
    # the full sequence, so every response began with an echo of the prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(tokens[0][prompt_len:], skip_special_tokens=True)
    history.append((message, response))
    return history, history
# Wire the chat function into a simple Gradio UI: a textbox plus hidden
# state in, a chatbot widget plus state out; flagging is disabled.
iface = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    allow_flagging="never",
)
iface.launch()