# app.py — veureu/schat (Salamandra 7B Instruct · ZeroGPU) — compatible with ENGINE
from __future__ import annotations

import os
import json
from typing import List, Dict, Any, Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from moe_tools import SalamandraClient

# ===== Config =====
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-instruct")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

_tok = None
_model = None
_salamandra = None


def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    global _tok, _model
    if _tok is None or _model is None:
        _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            trust_remote_code=True,
            device_map=None,
        ).to(DEVICE)
    return _tok, _model


def _build_prompt(prompt: str, system: Optional[str]) -> str:
    """
    If the tokenizer has a 'chat_template', use it with messages [system?, user].
    Otherwise, build a plain prompt with the system text at the top.
    """
    tok, _ = _lazy_load()
    messages = []
    if system and system.strip():
        messages.append({"role": "system", "content": system.strip()})
    messages.append({"role": "user", "content": prompt})
    chat_template = getattr(tok, "chat_template", None)
    if chat_template:
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Fallback without a chat template (Llama-style <<SYS>> markers)
    sys_part = (f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else "")
    return sys_part + f"### Instrucció\n{prompt}\n\n### Resposta\n"


# @spaces.GPU  # use GPU if available (ZeroGPU)
# def _generate_with_tools(
#     messages: List[Dict[str, str]],
#     tools: List[Dict[str, Any]],
#     max_new_tokens: int = 512,
#     temperature: float = 0.7,
#     top_p: float = 0.95,
# ) -> Dict[str, Any]:
#     # NOTE: relies on helpers (_render_tools_md, _compose_chat_prompt,
#     # maybe_execute_tool_calls) and on `import re`, none of which are defined
#     # in this file — keep commented out until they are provided.
#     tok, model = _lazy_load()
#     tools_md = _render_tools_md(tools)
#     prompt = _compose_chat_prompt(messages, tools_md)
#     inputs = tok(prompt, return_tensors="pt").to(DEVICE)
#     with torch.inference_mode():
#         out = model.generate(
#             **inputs,
#             max_new_tokens=int(max_new_tokens),
#             temperature=float(temperature),
#             top_p=float(top_p),
#             do_sample=temperature > 0,
#             pad_token_id=tok.eos_token_id,
#             eos_token_id=tok.eos_token_id,
#         )
#     text = tok.decode(out[0], skip_special_tokens=True).strip()
#     # If the model returns a JSON block with 'tool_calls', try to extract it
#     tool_calls: List[Dict[str, Any]] = []
#     try:
#         # Search for the last {...} block containing "tool_calls"
#         matches = list(re.finditer(r"\{.*?\"tool_calls\".*?\}", text, flags=re.S))
#         if matches:
#             block = text[matches[-1].start():matches[-1].end()]
#             obj = json.loads(block)
#             tc = obj.get("tool_calls", [])
#             if isinstance(tc, list):
#                 tool_calls = tc
#     except Exception:
#         pass
#     # Execute the extracted tool calls, if any
#     tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []
#     return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}
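# The block above expects the model to emit a JSON object of roughly this shape
# somewhere in its output (the outer {"tool_calls": [...]} structure follows from
# the regex and json.loads above; the inner fields are an illustrative assumption,
# since their interpretation lives in maybe_execute_tool_calls):
#
#   {"tool_calls": [{"name": "some_tool", "arguments": {"city": "Barcelona"}}]}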
@spaces.GPU  # use GPU if available (ZeroGPU)
def _generate(
    prompt: str,
    system: str = "",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> str:
    tok, model = _lazy_load()
    text = _build_prompt(prompt, system or "")
    inputs = tok(text, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    # The decoded text includes the rendered prompt; the helpers below strip it
    # by splitting on the 'assistant' role marker.
    return tok.decode(out[0], skip_special_tokens=True).strip()


# ------------------- Gradio Endpoints -------------------

# 1) /predict — what ENGINE expects (only 'prompt' → string)
def predict_for_engine(prompt: str) -> str:
    return _generate(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)


# 2) /generate — more controls (prompt + system + params)
def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    return _generate(prompt=prompt, system=system, max_new_tokens=max_new_tokens,
                     temperature=temperature, top_p=top_p)


def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
    global _salamandra
    if _salamandra is None:
        _salamandra = SalamandraClient()  # use your class
    try:
        text = _salamandra.chat(prompt)
    except Exception as e:
        text = f"Error running SalamandraClient: {str(e)}"
    return {"text": text}


def resume_sentence(sentence, num_words):
    """
    Summarizes the given sentence in the specified number of words.

    Parameters:
    - sentence (str): The sentence to summarize.
    - num_words (int): The number of words for the summary.

    Returns:
    - str: The summarized sentence.
    """
    num_words = int(num_words)
    # Prompt the model to summarize the sentence
    prompt = f"Instrucció: Resumeix la següent frase en {num_words} paraules. Input: {sentence}"
    result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
    # Keep only the model's reply if the output contains the 'assistant' role marker
    if "assistant" in result:
        clean_output = result.split("assistant", 1)[1].strip().split("\n")[0]
    else:
        clean_output = sentence  # fallback to the original sentence
    return clean_output


def identity_manager(sentence, person):
    """
    Replaces the subject of the sentence with the indicated person, keeping the rest unchanged.
    """
    prompt = f"""Instrucció: Substitueix el subjecte de la frase per la persona indicada, mantenint la resta igual.
Frase: {sentence}
Substitució: {person}
Resposta:"""
    # Generate the modified sentence using the advanced generator
    result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
    # Keep only the model's reply if the output contains the 'assistant' role marker
    if "assistant" in result:
        clean_output = result.split("assistant", 1)[1].strip().split("\n")[0]
    else:
        clean_output = sentence  # fallback to the original sentence
    return clean_output


def free_narration(srt_text):
    """
    Converts the given audio description into a short, natural, and coherent free narration.
    """
    prompt = f"""Instrucció: Converteix aquesta audiodescripció en una narració lliure breu, natural i coherent.
Input: {srt_text}
Output: """
    # Generate the free narration using the advanced generator
    result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
    # Keep only the model's reply if the output contains the 'assistant' role marker
    if "assistant" in result:
        clean_output = result.split("assistant", 1)[1].strip().split("\n")[0]
    else:
        clean_output = srt_text  # fallback to the original input
    return clean_output


# ------------------- HTTP (optional, for plain HTTP clients) -------------------
# If you want, you can add an HTTP POST /generate endpoint (FastAPI) here,
# but the Gradio Client is enough for ENGINE/local use.
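# A minimal sketch of that optional endpoint, kept commented out so it does not
# change this app's behavior. The names `api`, `GenerateRequest`, and the route
# are illustrative assumptions, not part of the deployed Space:
#
# from fastapi import FastAPI
# from pydantic import BaseModel
#
# api = FastAPI()
#
# class GenerateRequest(BaseModel):
#     prompt: str
#     system: str = ""
#     max_new_tokens: int = 512
#     temperature: float = 0.7
#     top_p: float = 0.95
#
# @api.post("/generate")
# def http_generate(req: GenerateRequest) -> dict:
#     return {"text": _generate(req.prompt, req.system, req.max_new_tokens,
#                               req.temperature, req.top_p)}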
# ------------------- UI -------------------
custom_css = """
h2 {
    background: #e3e4e6 !important;
    padding: 14px 22px !important;
    border-radius: 14px !important;
    box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
    display: block !important;   /* take the full width */
    width: 100% !important;      /* ensure 100% */
    margin: 20px auto !important;
    text-align: center;
}
"""

# App UI built with Gradio. This interface exposes several model utilities.
with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:
    # Section: Instruction-based text generation
    gr.Markdown("## Salamandra-7B-Instruct · ZeroGPU\nText → resposta instruccional.")
    with gr.Row():
        with gr.Column(scale=1):
            # System prompt (optional internal conditioning)
            in_system = gr.Textbox(label="Sistema (opcional)", value="")
            # User prompt to instruct the model
            in_prompt = gr.Textbox(label="Instrucció", placeholder="Escriu la teva instrucció…", lines=6)
            # Maximum number of new tokens to generate
            max_new = gr.Slider(16, 2048, value=512, step=16, label="Màxim de tokens nous")
            # Sampling temperature controlling randomness
            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperatura")
            # Nucleus sampling threshold
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
            # Button to trigger text generation
            btn = gr.Button("Generar", variant="primary")
        with gr.Column(scale=1):
            # Output box for generated text
            out = gr.Textbox(label="Resposta", lines=18)

    # Bind the main generation function
    btn.click(
        generate_advanced,
        [in_prompt, in_system, max_new, temp, top_p],
        out,
        api_name="generate",
        concurrency_limit=1,
    )

    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------

    # Minimal endpoint for ENGINE compatibility (/predict)
    # Only requires a prompt, returns generated text
    in_prompt_engine = gr.Textbox(label="Instrucció (ENGINE)", value="Digues hola en una frase.")
    out_engine = gr.Textbox(label="Resposta (ENGINE)")
    gr.Button("Provar /predict").click(
        predict_for_engine,
        [in_prompt_engine],
        out_engine,
        api_name="predict",
        concurrency_limit=1,
    )

    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------

    # Section: Sentence summarization
    gr.Markdown("<h2>Resumir frases</h2>")
    with gr.Row():
        with gr.Column(scale=1):
            # Text to summarize
            sentence = gr.Textbox(label="Frase a resumir", value="", lines=3)
            # Desired number of words in the summary
            num_words = gr.Textbox(label="Nombre de paraules del resum", value="4")
        with gr.Column(scale=1):
            # Output summary
            out_resume = gr.Textbox(label="Resposta", lines=18)
    with gr.Row():
        # Button to produce a summary
        btn_resume = gr.Button("Resumir", variant="primary")

    btn_resume.click(
        resume_sentence,
        inputs=[sentence, num_words],
        outputs=out_resume,
        api_name="resume",
        concurrency_limit=1,
    )

    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------

    # Section: Inclusion of identities inside text
    gr.Markdown("<h2>Inclusió d’identitats</h2>")
    with gr.Row():
        with gr.Column(scale=1):
            # Sentence to modify (renamed to avoid shadowing `sentence` above)
            sentence_mod = gr.Textbox(label="Frase a modificar", value="", lines=3)
            # Identity mapping provided by the user
            person = gr.Textbox(label="Persones reconegudes", value='"Mireia Martí": 4, "Xavier Busquets": 5')
        with gr.Column(scale=1):
            out_modificat = gr.Textbox(label="Resposta", lines=18)
    with gr.Row():
        btn_modify = gr.Button("Modificar frase", variant="primary")

    btn_modify.click(
        identity_manager,
        inputs=[sentence_mod, person],
        outputs=out_modificat,
        api_name="modificat",
        concurrency_limit=1,
    )

    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------

    # Section: Free narration generation from an SRT-like audio description
    gr.Markdown("<h2>Narració lliure</h2>")
    with gr.Row():
        with gr.Column(scale=1):
            # SRT-like structured description
            srt = gr.Textbox(
                label="Audiodescripció",
                value="(AD)\nTOTS CANTANT: avui celebrem la nostra festa major\nAINA: som hi tots a ballar",
                lines=3,
            )
            # Renamed to avoid shadowing `btn_modify` above
            btn_narrate = gr.Button("Generar narració lliure", variant="primary")
        with gr.Column(scale=1):
            narració_lliure = gr.Textbox(label="Narració lliure", lines=18)

    btn_narrate.click(
        free_narration,
        inputs=[srt],
        outputs=narració_lliure,
        api_name="narració",
        concurrency_limit=1,
    )

    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------

    # Section: Raw model output from a prompt (JSON)
    gr.Markdown("<h2>Sortida del model Salamandra a partir d’una petició</h2>")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=10)
    with gr.Row():
        btn2 = gr.Button("Generar", variant="primary")
    with gr.Row():
        out2 = gr.JSON(label="Sortida")

    btn2.click(
        salamandra_chat_endpoint,
        [prompt],
        out2,
        api_name="generate_out_from_prompt",
        concurrency_limit=1,
    )

    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------

# Queue to handle multiple requests safely
demo.queue(max_size=16).launch()