Spaces:
Paused
Paused
| import os | |
| import json | |
| import gradio as gr | |
| import google.generativeai as genai | |
# The Gemini API key is read from the environment and handed to the client.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=GOOGLE_API_KEY)

# Sampling parameters passed to the model.
generation_config = dict(
    temperature=0.9,
    top_p=1,
    top_k=1,
    max_output_tokens=2048,
)

# Every listed harm category uses the same "BLOCK_NONE" threshold.
safety_settings = [
    {"category": harm_category, "threshold": "BLOCK_NONE"}
    for harm_category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

# Shared model handle used by classify_msg.
model = genai.GenerativeModel(
    model_name="gemini-pro",
    generation_config=generation_config,
    safety_settings=safety_settings,
)
# Prompt template sent to the model; the single "{}" placeholder is filled
# with the user's category list via str.format in classify_msg.
task_description = (
    " You are an SMS (Short Message Service) reader who reads every message "
    "that the short message service centre receives and you need to classify "
    "each message among the following categories: {}"
    "<div>Let the output be a softmax function output giving the probability "
    "of message belonging to each category.</div>"
    "<div>The sum of the probabilities should be 1</div>"
    "<div>The output must be in JSON format</div>"
)
def classify_msg(categories, message):
    """Classify ``message`` into one of ``categories`` using the Gemini model.

    The prompt is built from ``task_description`` (with the category list
    substituted in), the message, and a trailing "Category: " cue. The model
    reply is expected to contain a JSON object mapping each category to a
    probability; the outermost ``{...}`` span of the reply is parsed.

    Args:
        categories: Comma-separated category names typed by the user.
        message: The text to classify.

    Returns:
        A ``gr.Label`` whose value is the parsed category-to-score mapping.

    Raises:
        gr.Error: If either input is blank, the model returns no text, or the
            reply does not contain a parseable JSON object. ``gr.Error`` is
            used so that Gradio shows the message to the user instead of a
            raw traceback.
    """
    if not (categories or "").strip() or not (message or "").strip():
        raise gr.Error("Please provide both a list of categories and a message.")

    prompt_parts = [
        task_description.format(categories),
        f"Message: {message}",
        "Category: ",
    ]
    response = model.generate_content(prompt_parts)

    # Read the text once; the client's `.text` accessor raises ValueError
    # when the response carries no text part (e.g. a blocked prompt).
    try:
        reply_text = response.text
    except ValueError as err:
        raise gr.Error("The model did not return any text for this request.") from err

    # Keep only the outermost {...} span so surrounding prose or code fences
    # in the reply do not break JSON parsing.
    start = reply_text.find("{")
    end = reply_text.rfind("}")
    if start == -1 or end < start:
        raise gr.Error("The model response did not contain a JSON object.")

    try:
        json_response = json.loads(reply_text[start : end + 1])
    except json.JSONDecodeError as err:
        raise gr.Error("The model response could not be parsed as JSON.") from err

    return gr.Label(json_response)
def clear_inputs_and_outputs():
    """Return one ``None`` per bound component to reset the form.

    The Clear button wires this to three outputs (categories, message and
    prediction label), hence three values.
    """
    return [None] * 3
# Page layout: a header, an input column (categories, message, buttons) next
# to an output column (prediction label), followed by cached examples.
with gr.Blocks() as demo:
    gr.Markdown(
        """
<h1 align="center">Multi-language Text Classifier using Gemini Pro</h1> \
This space uses Gemini Pro in order to classify texts.<br> \
Depending on the list of categories that you specify, you can have text classifier, a SPAM detector, a sentiment classifier, ... <br><br> \
<b>For the categories, enter a list of words separated by commas</b><br><br>
"""
    )

    with gr.Row():
        # Left column: user inputs and action buttons.
        with gr.Column():
            with gr.Row():
                categories = gr.Textbox(
                    placeholder="Input the list of categories as comma separated words",
                    label="Categories",
                )
            with gr.Row():
                message = gr.Textbox(placeholder="Enter Message", label="Message")
            with gr.Row():
                clr_btn = gr.Button("Clear", variant="secondary")
                csf_btn = gr.Button("Classify")
        # Right column: classification result.
        with gr.Column():
            lbl_output = gr.Label(label="Prediction")

    # "Clear" resets both textboxes and the prediction label.
    clr_btn.click(
        clear_inputs_and_outputs,
        inputs=[],
        outputs=[categories, message, lbl_output],
    )
    # "Classify" sends the two inputs to the model and shows the scores.
    csf_btn.click(
        classify_msg,
        inputs=[categories, message],
        outputs=lbl_output,
    )

    # Example rows are [categories, message]; results are cached.
    gr.Examples(
        fn=classify_msg,
        inputs=[categories, message],
        outputs=[lbl_output],
        cache_examples=True,
        examples=[
            ["Normal, Promotional, Urgent", "Will you be passing by?"],
            ["Spam, Ham", "Plus de 300 % de perte de poids pendant le régime."],
            ["Χαρούμενος, Δυστυχισμένος", "Η εξυπηρέτηση σας ήταν απαίσια"],
            ["مهم، أقل أهمية ", "خبر عاجل"],
        ],
    )

# Enable the request queue (API closed), then start the server.
demo.queue(api_open=False).launch(debug=True, share=True, show_api=False)