# (scraped page header — file size: 2,976 bytes; revision ids 92e3dc5, fa8a6c8)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from transformers import pipeline
import pandas as pd
import gradio as gr
import os
import copy
import spaces

from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextStreamer, TextIteratorStreamer


# quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# Hugging Face Hub id of the chat model; shared by the model and tokenizer loads.
MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"

# Pick the best available accelerator. torch.backends.mps.is_available() is the
# documented MPS check; torch.mps may not expose is_available on every release.
torch_device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# bfloat16 on accelerators (half the memory of float32); CPU stays in float32.
torch_dtype = torch.bfloat16 if torch_device in ["cuda", "mps"] else torch.float32

llama_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # quantization_config=quantization_config,
    torch_dtype=torch_dtype,
    device_map=torch_device,
)

llama_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# streamer = TextStreamer(llama_tokenizer)

# Shared text-generation pipeline; llama32_1b_chat calls it with a list of
# role/content message dicts.
llama32_1b_pipe = pipeline(
    "text-generation",
    model=llama_model,
    tokenizer=llama_tokenizer,
    # streamer = streamer,
)

def context_window_limiting(history: list[dict], context_window: int, tokenizer=None) -> list[dict]:
    """Drop the oldest messages until the chat-templated history fits the window.

    Whole messages are removed from the front of a copy of ``history`` while the
    length of ``tokenizer.apply_chat_template(...)`` is greater than or equal to
    ``context_window``. The newest message is never removed: if it alone exceeds
    the window it is returned by itself rather than leaving nothing to send to
    the model.

    Args:
        history: Chat messages (role/content dicts), oldest first. Not modified.
        context_window: Token budget; the result templates to fewer tokens than
            this unless only a single message remains.
        tokenizer: Object with an ``apply_chat_template(messages)`` method whose
            result length is used as the token count. Defaults to the tokenizer
            of the module-level ``llama32_1b_pipe``.

    Returns:
        A deep copy of the most recent suffix of ``history`` that fits.
    """
    if tokenizer is None:
        tokenizer = llama32_1b_pipe.tokenizer

    history_windowed = copy.deepcopy(history)
    # Test the length first so apply_chat_template never receives an empty list,
    # and stop at one message so the current user turn is always kept.
    while len(history_windowed) > 1 and (
        len(tokenizer.apply_chat_template(history_windowed)) >= context_window
    ):
        del history_windowed[0]  # drop the oldest message
    # DEBUG
    print(f"number of messages in chat hist: {len(history_windowed)}")
    return history_windowed

@spaces.GPU
def llama32_1b_chat(message, history, context_window) -> str:
    """Answer one chat turn with the Llama 3.2 1B pipeline and return only the reply text.

    Args:
        message: The new user message.
        history: Prior turns as role/content dicts (the Gradio ``type="messages"``
            format); a deep copy is used, so the caller's list is not modified.
        context_window: Token budget passed to ``context_window_limiting``.

    Returns:
        The ``content`` of the last message in the pipeline's ``generated_text``.
    """
    conversation = copy.deepcopy(history) + [{"role": "user", "content": message}]
    trimmed = context_window_limiting(conversation, context_window)

    generation = llama32_1b_pipe(trimmed, max_new_tokens=512)
    final_turn = generation[-1]["generated_text"][-1]
    return final_turn["content"]
    


# Create the Gradio interface
def create_interface():
    """Build the Gradio Blocks app: a context-window slider above a chat interface.

    The slider's value is forwarded to ``llama32_1b_chat`` as its
    ``context_window`` argument via ``additional_inputs``.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            window_slider = gr.Slider(
                64,
                1024,
                value=256,
                label="size of context window",
                info="choose context window size",
            )
        with gr.Row():
            gr.ChatInterface(
                fn=llama32_1b_chat,
                additional_inputs=[window_slider],
                type="messages",
                title="context_window",
            )
    return demo

# Build the app at import time so `demo` stays available as a module attribute,
# but only start the server when this file is executed as a script.
demo = create_interface()

if __name__ == "__main__":
    demo.launch()