```python
from typing import List

from huggingface_hub import InferenceClient
from openai import OpenAI


def detect_abstain(text: str, api: str, model: str) -> str:
    """Classify whether a generated biography passage abstains from answering."""
    if api == "openai":
        client = OpenAI()
    elif api == "hf":
        # InferenceClient exposes an OpenAI-compatible chat completion interface.
        client = InferenceClient()
    else:
        raise ValueError(f"Invalid API: {api}")

    detect_abstain_prompt = f"""
You are given a piece of text that is a part of a biography of an entity.
Text: {text}
If the text claims a lack of knowledge about the topic, return "Abstain".
Otherwise, return "Not abstain".
"""

    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": detect_abstain_prompt},
        ],
    )
    return completion.choices[0].message.content.strip()
```
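For context, a minimal usage sketch follows. The model identifier and input text are illustrative placeholders (not from the original), and the call assumes the relevant credentials (e.g. `OPENAI_API_KEY`) are already set in the environment:

```python
# Hypothetical usage: classify a generated biography passage.
# "gpt-4o-mini" is an illustrative placeholder, not a prescribed model.
label = detect_abstain(
    text="I'm sorry, but I don't have any information about this person.",
    api="openai",
    model="gpt-4o-mini",
)
print(label)  # a refusal like the one above should be labeled "Abstain"
```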
```python
def calculate_factf1_at_k(
    supported_facts: List[str], unsupported_facts: List[str], k: int
) -> float:
    """Compute F1@K from lists of supported and unsupported facts.

    Precision is the fraction of generated facts that are supported;
    recall is capped at 1 and measured against a target of K supported facts.
    """
    if len(supported_facts) == 0:
        return 0.0
    precision = len(supported_facts) / (len(supported_facts) + len(unsupported_facts))
    recall = min(len(supported_facts) / k, 1)
    f1 = 2 * precision * recall / (precision + recall)
    return f1
```
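As a sanity check, here is a small worked example; the fact counts are made up purely for illustration:

```python
# Illustrative numbers only: 20 supported and 5 unsupported facts, K = 32.
supported = [f"supported fact {i}" for i in range(20)]
unsupported = [f"unsupported fact {i}" for i in range(5)]

score = calculate_factf1_at_k(supported, unsupported, k=32)
# precision = 20 / 25 = 0.8
# recall    = min(20 / 32, 1) = 0.625
# F1        = 2 * 0.8 * 0.625 / (0.8 + 0.625) ≈ 0.702
print(round(score, 3))  # 0.702
```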