File size: 3,040 Bytes
c5b3be4
 
d89516a
f26d471
c5b3be4
 
f26d471
c5b3be4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f26d471
d89516a
f26d471
c5b3be4
f26d471
 
 
 
c5b3be4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f26d471
c5b3be4
 
 
 
 
 
 
 
 
 
 
 
8692da2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
# import spaces
import functools

# Download only the safety model shard from the model repo



# Load the Llama model for safety classification
# model_path = r".\models\llama-guard-3-8b-q4_k_m.gguf"


# llm = Llama(model_path=model_path, n_ctx=1024)
# llm = Llama(
#     model_path=model_path,
#     n_ctx=512,            # down from 4096
#     low_vram=True,         # llama.cpp low-vram mode
#     f16_kv=True,           # half-precision kv cache
#     use_mmap=True,         # mmap file
#     use_mlock=False,
# )

# @spaces.CPU
@functools.lru_cache(maxsize=1)
def llm_gpu():
    """Fetch the Llama Guard 3 8B GGUF file from the Hub and load it with llama.cpp.

    ``functools.lru_cache(maxsize=1)`` memoizes the result. After the first
    successful call, the same ``Llama`` instance is returned without
    re-downloading or reloading the model.

    Returns:
        A ``llama_cpp.Llama`` instance configured for short classification
        prompts.
    """
    gguf_file = hf_hub_download(
        filename="llama-guard-3-8b-q4_k_m.gguf",
        repo_id="Inventors-Hub/SwarmChat-models",
        repo_type="model",
    )

    # Small context window (reduced from 4096) keeps memory use low.
    # NOTE(review): newer llama-cpp-python releases may ignore low_vram /
    # f16_kv; confirm against the installed version.
    return Llama(
        model_path=gguf_file,
        n_ctx=512,
        low_vram=True,    # llama.cpp low-VRAM mode
        f16_kv=True,      # half-precision KV cache
        use_mmap=True,    # memory-map the model file
        use_mlock=False,  # do not pin model pages in RAM
    )

# Llama Guard 3 hazard taxonomy: category code -> human-readable name.
_SAFETY_CATEGORIES = {
    "S1": "Violent Crimes.",
    "S2": "Non-Violent Crimes.",
    "S3": "Sex-Related Crimes.",
    "S4": "Child Sexual Exploitation.",
    "S5": "Defamation.",
    "S6": "Specialized Advice.",
    "S7": "Privacy.",
    "S8": "Intellectual Property.",
    "S9": "Indiscriminate Weapons.",
    "S10": "Hate.",
    "S11": "Suicide & Self-Harm.",
    "S12": "Sexual Content.",
    "S13": "Elections.",
    "S14": "Code Interpreter Abuse.",
}


def _parse_safety_response(response, categories):
    """Turn the model's raw completion into the user-facing verdict string.

    Returns ``"Safe"`` when the completion starts with "safe" (case-insensitive).
    Anything else is treated as unsafe. Every category code found after the
    word "unsafe" is mapped to its name via *categories*, and unrecognized
    tokens are ignored. If no known code is present, the label is
    ``"Unknown"``, so the check fails closed instead of raising ``KeyError``.
    """
    text = response.strip()
    lowered = text.lower()
    if lowered.startswith("safe"):
        return "Safe"

    marker = lowered.find("unsafe")
    tail = text[marker + len("unsafe"):] if marker != -1 else text
    # Codes may be split by commas and/or whitespace, e.g. "unsafe\nS1,S10".
    # dict.fromkeys de-duplicates while keeping the model's order.
    codes = dict.fromkeys(token.upper() for token in tail.replace(",", " ").split())
    names = [categories[code] for code in codes if code in categories]
    label = ", ".join(names) if names else "Unknown"
    return f"Unsafe: This prompt is categorized as '{label}'"


def check_safety(text, checkbox):
    """Classify *text* with Llama Guard 3 and return a verdict string.

    Args:
        text: The prompt to classify.
        checkbox: When truthy, the check is bypassed and ``"Safe"`` is returned
            without loading the model.

    Returns:
        ``"Safe"``, or ``"Unsafe: This prompt is categorized as '<names>'"``.
        ``<names>`` is the comma-separated list of violated category names,
        or ``Unknown`` when the model's answer named no recognized category.
    """
    if checkbox:
        return "Safe"

    category_lines = "\n".join(
        f"{code}: {name}" for code, name in _SAFETY_CATEGORIES.items()
    )
    # The prompt is built from flat string pieces so no source-code
    # indentation leaks into what the model sees.
    # <|begin_of_text|> is omitted because llama-cpp-python tokenizes
    # completion prompts with add_bos=True and already prepends BOS for this
    # Llama 3 GGUF; an explicit one would duplicate it.
    # The trailing assistant header is part of the Llama Guard 3 / Llama 3
    # chat format. It makes the model answer as the assistant rather than
    # continue the user turn.
    input_prompt = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        "Task: Determine if the following text contains unsafe content "
        "according to our safety policy. Classify the content into one or "
        "more of the categories below if unsafe.\n\n"
        "<BEGIN UNSAFE CONTENT CATEGORIES>\n"
        f"{category_lines}\n"
        "<END UNSAFE CONTENT CATEGORIES>\n\n"
        f"Text: {text}\n\n"
        "Provide your safety assessment:\n"
        '- Respond with "safe" if the text is appropriate.\n'
        '- If unsafe, respond with: "unsafe" followed by a comma-separated '
        "list of the violated categories (e.g., S1, S2)."
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    llm = llm_gpu()
    output = llm(input_prompt, temperature=0)
    # Guard against a missing *or empty* "choices" list.
    choices = output.get("choices") or [{}]
    response = choices[0].get("text", "").strip()
    print("\n\n", response, "\n\n")
    return _parse_safety_response(response, _SAFETY_CATEGORIES)