| """ | |
| License: | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| In no event shall the authors or copyright holders be liable | |
| for any claim, damages or other liability, whether in an action of contract,otherwise, | |
| arising from, out of or in connection with the software or the use or | |
| other dealings in the software. | |
| Copyright (c) 2024 pi19404. All rights reserved. | |
| Authors: | |
| pi19404 <[email protected]> | |
| """ | |
| """ | |
| Gradio Interface for Shield Gemma LLM Evaluator | |
| This module provides a Gradio interface to interact with the Shield Gemma LLM Evaluator. | |
| It allows users to input JSON data and select various options to evaluate the content | |
| for policy violations. | |
| Functions: | |
| my_inference_function: The main inference function to process input data and return results. | |
| """ | |
import gradio as gr
from gradio_client import Client
import json
import threading
import os
import time
from collections import OrderedDict

import httpx

API_TOKEN = os.getenv("API_TOKEN")

lock = threading.Lock()

# Create an OrderedDict to store clients, limited to MAX_CACHE_SIZE entries.
client_cache = OrderedDict()
MAX_CACHE_SIZE = 15
def my_inference_function(client, input_data, output_data, mode, max_length, max_new_tokens, model_size):
    """
    The main inference function to process input data and return results.

    Args:
        client (Client): The Gradio client connected to the backend Space.
        input_data (str or dict): The input data in JSON format.
        output_data (str): The response text to evaluate (may be empty).
        mode (str): The mode of operation ("scoring" or "generative").
        max_length (int): The maximum length of the input prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        model_size (str): The size of the model to be used ("2B", "9B" or "27B").

    Returns:
        str: The output data in JSON format.
    """
    with lock:
        try:
            result = client.predict(
                input_data=input_data,
                output_data=output_data,
                mode=mode,
                max_length=max_length,
                max_new_tokens=max_new_tokens,
                model_size=model_size,
                api_name="/my_inference_function"
            )
            print("result:", result)
            return result
        except json.JSONDecodeError:
            return json.dumps({"error": "Invalid JSON input"})
        except KeyError:
            return json.dumps({"error": "Missing 'input' key in JSON"})
        except ValueError as e:
            return json.dumps({"error": str(e)})
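
# Illustrative call, assuming the backend Space is up and API_TOKEN is valid.
# The exact JSON shape the worker expects is an assumption here, inferred from
# the "Missing 'input' key in JSON" error message above:
#
#     reply = my_inference_function(
#         default_client, '{"input": "Is this text harmful?"}', "",
#         "scoring", 150, 1024, "2B",
#     )
#     print(reply)  # JSON string from the worker, or a JSON error object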
def wake_up_space_with_retries(space_url, token, retries=5, wait_time=10):
    """
    Attempt to wake up the Hugging Face Space with retries.

    Retries a number of times in case of a delay due to the Space waking up.

    :param space_url: The URL of the Hugging Face Space.
    :param token: The Hugging Face API token.
    :param retries: Number of retries if the Space is sleeping.
    :param wait_time: Time to wait between retries (in seconds).
    :return: A connected Client instance, or None if the Space never woke up.
    """
    for attempt in range(retries):
        try:
            print(f"Attempt {attempt + 1} to wake up the Space...")
            # Initialize the Gradio client with a 30-second timeout.
            client = Client(space_url, hf_token=token, timeout=httpx.Timeout(30.0))
            # Make a lightweight prediction to wake the Space.
            my_inference_function(client, "test input", "", "scoring", 10, 10, "2B")
            print("Space is awake and ready!")
            return client
        except httpx.ReadTimeout:
            print(f"Request timed out on attempt {attempt + 1}. Retrying in {wait_time} seconds...")
            time.sleep(wait_time)
        except Exception as e:
            print(f"An error occurred on attempt {attempt + 1}: {e}")
            # Wait before retrying.
            if attempt < retries - 1:
                print(f"Waiting for {wait_time} seconds before retrying...")
                time.sleep(wait_time)
    print("Space is still not active after multiple attempts.")
    return None
default_client = wake_up_space_with_retries("pi19404/ai-worker", API_TOKEN)
def get_client_for_ip(ip_address, x_ip_token):
    """
    Retrieve or create a client for the given IP address.

    This function implements a caching mechanism to store up to MAX_CACHE_SIZE clients.
    If a client for the given IP exists in the cache, it's returned and moved to the end
    of the cache (marking it as most recently used). If not, a new client is created,
    added to the cache, and the least recently used client is removed if the cache is full.

    Args:
        ip_address (str): The IP address of the client.
        x_ip_token (str): The X-IP-Token header value for the client.

    Returns:
        Client: A Gradio client instance for the given IP address.
    """
    # Fall back to the IP address when no X-IP-Token header was sent.
    if x_ip_token is None:
        x_ip_token = ip_address
    if x_ip_token is None:
        # No usable cache key at all; reuse the shared default client.
        return default_client
    if x_ip_token in client_cache:
        # Move the accessed item to the end (most recently used).
        client_cache.move_to_end(x_ip_token)
        return client_cache[x_ip_token]
    # Create a new client scoped to this caller's token.
    new_client = Client("pi19404/ai-worker", hf_token=API_TOKEN, headers={"X-IP-Token": x_ip_token})
    # Add to cache, evicting the least recently used entry if necessary.
    if len(client_cache) >= MAX_CACHE_SIZE:
        client_cache.popitem(last=False)
    client_cache[x_ip_token] = new_client
    return new_client
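
# A minimal sketch of the OrderedDict-based LRU policy used above, with plain
# strings standing in for Client objects (names and values are illustrative):
def _lru_cache_sketch():
    cache = OrderedDict()
    max_size = 2
    for key in ["a", "b", "a", "c"]:
        if key in cache:
            cache.move_to_end(key)         # mark as most recently used
        else:
            if len(cache) >= max_size:
                cache.popitem(last=False)  # evict least recently used
            cache[key] = f"client-{key}"
    return list(cache)  # ["a", "c"]: "b" was evicted as least recently used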
def set_client_for_session(request: gr.Request):
    """
    Set up a client for the current session.

    This function is called when a new session is initiated. It retrieves or creates
    a client for the session's IP address and prints all request headers for debugging.

    Args:
        request (gr.Request): The Gradio request object for the current session.

    Returns:
        Client: The Gradio client instance for the session.
    """
    # Collect all headers in a dictionary and print them for debugging.
    all_headers = {header: value for header, value in request.headers.items()}
    print("All request headers:")
    print(json.dumps(all_headers, indent=2))
    x_ip_token = request.headers.get('x-ip-token', None)
    ip_address = request.client.host
    print("ip address is", ip_address)
    return get_client_for_ip(ip_address, x_ip_token)
| # The "gradio/text-to-image" space is a ZeroGPU space | |
| with gr.Blocks() as demo: | |
| """ | |
| Main Gradio interface setup. | |
| This block sets up the Gradio interface, including: | |
| - A State component to store the client for the session. | |
| - A JSON component to display request headers for debugging. | |
| - Other UI components (not shown in this snippet). | |
| - A load event that calls set_client_for_session when the interface is loaded. | |
| """ | |
    gr.Markdown("## LLM Safety Evaluation")
    client = gr.State()

    with gr.Tab("ShieldGemma2"):
        input_text = gr.Textbox(label="Input Text")
        output_text = gr.Textbox(
            label="Response Text",
            lines=5,
            max_lines=10,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        mode_input = gr.Dropdown(choices=["scoring", "generative"], label="Prediction Mode")
        max_length_input = gr.Number(label="Max Length", value=150)
        max_new_tokens_input = gr.Number(label="Max New Tokens", value=1024)
        model_size_input = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
        response_text = gr.Textbox(
            label="Output Text",
            lines=10,
            max_lines=20,
            show_copy_button=True,
            elem_classes=["wrap-text"]
        )
        text_button = gr.Button("Submit")
        text_button.click(
            fn=my_inference_function,
            inputs=[client, input_text, output_text, mode_input,
                    max_length_input, max_new_tokens_input, model_size_input],
            outputs=response_text
        )
    # with gr.Tab("API Input"):
    #     api_input = gr.JSON(label="Input JSON")
    #     mode_input_api = gr.Dropdown(choices=["scoring", "generative"], label="Mode")
    #     max_length_input_api = gr.Number(label="Max Length", value=150)
    #     max_new_tokens_input_api = gr.Number(label="Max New Tokens", value=None)
    #     model_size_input_api = gr.Dropdown(choices=["2B", "9B", "27B"], label="Model Size")
    #     api_output = gr.JSON(label="Output JSON")
    #     api_button = gr.Button("Submit")
    #     api_button.click(fn=my_inference_function, inputs=[client, api_input, api_output, mode_input_api, max_length_input_api, max_new_tokens_input_api, model_size_input_api], outputs=api_output)
    demo.load(set_client_for_session, None, client)

demo.launch(share=True)
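
# A hedged sketch of querying the backend Space directly with gradio_client.
# The parameter names and api_name mirror the predict call inside
# my_inference_function above; the argument values are illustrative only:
#
#     from gradio_client import Client
#     c = Client("pi19404/ai-worker", hf_token=API_TOKEN)
#     print(c.predict(
#         input_data='{"input": "Is this text harmful?"}',  # "input" key assumed
#         output_data="",
#         mode="scoring",
#         max_length=150,
#         max_new_tokens=1024,
#         model_size="2B",
#         api_name="/my_inference_function",
#     ))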