#!/usr/bin/env python3
"""Gradio front end for remote text generation with a causal language model.

The app walks an ordered list of model repositories, keeps the first one that
loads successfully, and exposes a single ``generate`` endpoint through a
Gradio Blocks UI.  All tunables are read from environment variables at import
time (see the constants below).
"""

import os
import threading
from typing import Any

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    import spaces
except Exception:
    # Any import failure falls back to a no-op decorator so the app still
    # runs where the `spaces` package is unavailable.
    class _SpacesFallback:
        """Stand-in for `spaces` that exposes a pass-through `GPU` decorator."""

        @staticmethod
        def GPU(duration: int = 60):
            def _decorator(fn):
                return fn

            return _decorator

    spaces = _SpacesFallback()


DEFAULT_FULL_MODEL = "NousResearch/nomos-1"
DEFAULT_MODEL_CANDIDATES = "cyankiwi/nomos-1-AWQ-8bit,cyankiwi/nomos-1-AWQ-4bit"

GPU_DURATION_SECONDS = int(os.getenv("GPU_DURATION_SECONDS", "120"))
MAX_INPUT_TOKENS = int(os.getenv("MAX_INPUT_TOKENS", "2048"))
MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256"))
TRUST_REMOTE_CODE = os.getenv("TRUST_REMOTE_CODE", "true").lower() == "true"
PREFER_FULL = os.getenv("PREFER_FULL", "false").lower() == "true"

# Process-wide model state; writes happen only while holding _MODEL_LOCK.
_MODEL_LOCK = threading.Lock()
_MODEL: Any = None
_TOKENIZER: Any = None
_MODEL_ID: str | None = None
_LOAD_ERRORS: list[str] = []


def _ordered_candidates() -> list[str]:
    """Return the model repository ids to try, in priority order.

    The list comes from the ``MODEL_CANDIDATES`` environment variable (comma
    separated), falling back to ``DEFAULT_MODEL_CANDIDATES``.  Blank entries
    and repeated ids are dropped.  When ``PREFER_FULL`` is set, the full model
    is placed first even if it was already listed further down.
    """
    configured = os.getenv("MODEL_CANDIDATES", DEFAULT_MODEL_CANDIDATES)
    # dict.fromkeys de-duplicates while preserving first-seen order, so a
    # failing repository is not attempted twice.
    candidates = list(
        dict.fromkeys(m.strip() for m in configured.split(",") if m.strip())
    )
    if PREFER_FULL:
        candidates = [DEFAULT_FULL_MODEL] + [
            m for m in candidates if m != DEFAULT_FULL_MODEL
        ]
    return candidates


def _load_model_if_needed() -> tuple[str | None, str]:
    """Load the first candidate model that works, once per process.

    Returns:
        ``(model_id, message)``.  ``model_id`` is the repository id of the
        loaded model, or ``None`` when every candidate failed; ``message`` is
        one of ``"model already loaded"``, ``"loaded"`` or ``"load failed"``.

    Side effects:
        On success sets ``_MODEL``, ``_TOKENIZER`` and ``_MODEL_ID`` and clears
        ``_LOAD_ERRORS``.  On failure replaces ``_LOAD_ERRORS`` with one entry
        per failed candidate.
    """
    global _MODEL, _TOKENIZER, _MODEL_ID

    def _ready() -> bool:
        return (
            _MODEL is not None
            and _TOKENIZER is not None
            and _MODEL_ID is not None
        )

    # Cheap check first so already-loaded calls do not wait on the lock.
    if _ready():
        return _MODEL_ID, "model already loaded"

    with _MODEL_LOCK:
        # Re-check: another thread may have finished loading while we waited.
        if _ready():
            return _MODEL_ID, "model already loaded"

        candidates = _ordered_candidates()
        if not candidates:
            _LOAD_ERRORS[:] = [
                "no model candidates configured (MODEL_CANDIDATES is empty)"
            ]
            return None, "load failed"

        errors: list[str] = []
        for candidate in candidates:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    candidate,
                    trust_remote_code=TRUST_REMOTE_CODE,
                )
                model = AutoModelForCausalLM.from_pretrained(
                    candidate,
                    device_map="auto",
                    trust_remote_code=TRUST_REMOTE_CODE,
                    low_cpu_mem_usage=True,
                )
                model.eval()
            except Exception as exc:
                # Best effort: record why this candidate failed and move on
                # to the next one.
                errors.append(f"{candidate}: {type(exc).__name__}: {exc}")
                continue

            _TOKENIZER = tokenizer
            _MODEL = model
            _MODEL_ID = candidate
            _LOAD_ERRORS.clear()
            return candidate, "loaded"

        _LOAD_ERRORS[:] = errors
        return None, "load failed"


def _status_text() -> str:
    """Build the Markdown status panel text.

    Shows the loaded model id (or ``none``), the configured candidates, the
    GPU duration and input-token limit, plus up to the three most recent load
    errors when any are recorded.
    """
    candidates = ", ".join(_ordered_candidates())
    loaded = _MODEL_ID or "none"
    base = (
        f"Loaded model: `{loaded}`\n\n"
        f"Candidates: `{candidates}`\n\n"
        f"GPU duration: `{GPU_DURATION_SECONDS}s` | "
        f"Max input tokens: `{MAX_INPUT_TOKENS}`"
    )
    if _LOAD_ERRORS:
        err = "\n".join(f"- {e}" for e in _LOAD_ERRORS[-3:])
        return base + "\n\nRecent load errors:\n" + err
    return base


@spaces.GPU(duration=GPU_DURATION_SECONDS)
def generate(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    do_sample: bool,
) -> tuple[str, str]:
    """Generate a completion for a single-turn user prompt.

    Args:
        prompt: User message; surrounding whitespace is stripped.
        max_new_tokens: Upper bound on generated tokens (clamped to >= 1).
        temperature: Sampling temperature.  A value <= 0 selects greedy
            decoding, because sampling requires a strictly positive value.
        top_p: Nucleus sampling threshold; used only when sampling.
        top_k: Top-k cutoff; used only when sampling.
        do_sample: Whether to sample rather than decode greedily.

    Returns:
        ``(text, status_markdown)``.  ``text`` is the decoded completion, or a
        short explanatory message when the prompt is empty or no model could
        be loaded.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        return "Provide a prompt.", _status_text()

    model_id, _ = _load_model_if_needed()
    if model_id is None:
        return "Model load failed. Check status and Space logs.", _status_text()

    tokenizer = _TOKENIZER
    model = _MODEL

    messages = [{"role": "user", "content": prompt}]
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # NOTE(review): newer transformers releases appear to return a mapping
    # with an "input_ids" entry here instead of a bare tensor -- confirm
    # against the pinned version.  Accept both shapes.
    if not torch.is_tensor(encoded):
        encoded = encoded["input_ids"]
    input_ids = encoded.to(model.device)

    # Keep the tail so the generation prompt at the end survives truncation.
    if input_ids.shape[-1] > MAX_INPUT_TOKENS:
        input_ids = input_ids[:, -MAX_INPUT_TOKENS:]

    # Sampling with a non-positive temperature is rejected by generate(), so
    # treat it as a request for greedy decoding.
    sampling = bool(do_sample) and float(temperature) > 0.0

    gen_kwargs: dict[str, Any] = {
        "input_ids": input_ids,
        # One unpadded sequence: every position is a real token.  Passing the
        # mask explicitly matters because pad and EOS share an id below.
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": max(1, int(max_new_tokens)),
        "do_sample": sampling,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if sampling:
        gen_kwargs.update(
            {
                "temperature": float(temperature),
                "top_p": float(top_p),
                "top_k": int(top_k),
            }
        )

    with torch.no_grad():
        output_ids = model.generate(**gen_kwargs)

    # Drop the prompt tokens so only the newly generated text is returned.
    generated_ids = output_ids[0][input_ids.shape[-1]:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    if not text:
        # Nothing visible was generated; fall back to decoding everything.
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    return text, _status_text()


with gr.Blocks(title="Nomos ZeroGPU Inference") as demo:
    gr.Markdown(
        "# Nomos Remote Inference (ZeroGPU)\n"
        "This app tries model candidates in order and keeps the first that loads."
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="Prompt",
                lines=10,
                placeholder="Ask for a concise proof or solution sketch...",
            )
            with gr.Row():
                max_new_tokens = gr.Slider(
                    minimum=32,
                    maximum=1024,
                    value=MAX_NEW_TOKENS_DEFAULT,
                    step=1,
                    label="Max new tokens",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=20,
                    step=1,
                    label="Top-k",
                )
            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.6,
                    step=0.01,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.05,
                    maximum=1.0,
                    value=0.95,
                    step=0.01,
                    label="Top-p",
                )
            do_sample = gr.Checkbox(value=True, label="Sample")
            run_btn = gr.Button("Generate")
        with gr.Column(scale=2):
            output = gr.Textbox(label="Output", lines=18)
            status = gr.Markdown(value=_status_text())

    run_btn.click(
        fn=generate,
        inputs=[prompt, max_new_tokens, temperature, top_p, top_k, do_sample],
        outputs=[output, status],
        api_name="generate",
    )

    gr.Examples(
        examples=[
            ["Solve: Find all integers n such that n^2 + n + 1 is prime."],
            ["Give a proof sketch that there are infinitely many primes."],
        ],
        inputs=prompt,
    )

demo.queue(max_size=32)

if __name__ == "__main__":
    demo.launch()