"""Gradio demo that generates film plots with a GPT model pulled from the Hugging Face Hub."""

import gc
import json

import gradio as gr
import tiktoken
import torch
from huggingface_hub import hf_hub_download

from mingpt.model import GPT

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
REPO_ID = "to0ony/final-thesis-plotgen"

# Process-wide cache: "model" is filled in by load_model() on first use,
# "enc" is the GPT-2 tokenizer built once at import time.
state = {"model": None, "enc": tiktoken.get_encoding("gpt2")}


def load_model():
    """Return the GPT model, downloading and building it on the first call.

    The first call fetches ``config.json`` and ``model.pt`` from ``REPO_ID``,
    builds a ``GPT`` from the config values, loads the weights strictly,
    moves the model to ``DEVICE``, switches it to eval mode and caches it in
    ``state["model"]``. Later calls return the cached instance.

    Returns:
        The cached ``GPT`` instance.
    """
    if state["model"] is not None:
        return state["model"]

    # Download the config and the weights.
    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    mdl_path = hf_hub_download(repo_id=REPO_ID, filename="model.pt")

    # Read the architecture hyper-parameters.
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    # NOTE(review): some minGPT versions require exactly one of `model_type`
    # or the explicit layer/head/embd sizes to be set - confirm against the
    # bundled mingpt.model whether `gcfg.model_type = None` is needed here.
    gcfg = GPT.get_default_config()
    for key in ("vocab_size", "block_size", "n_layer", "n_head", "n_embd"):
        setattr(gcfg, key, cfg[key])

    model = GPT(gcfg)

    # NOTE(review): torch.load unpickles the downloaded file. If model.pt is a
    # plain state dict, consider passing weights_only=True - confirm first.
    sd = torch.load(mdl_path, map_location="cpu")
    model.load_state_dict(sd, strict=True)
    model.to(DEVICE)
    model.eval()

    state["model"] = model
    return model


@torch.inference_mode()
def generate(prompt, max_new_tokens=200, temperature=0.9, top_k=50):
    """Generate text that continues ``prompt``.

    Args:
        prompt: Text to continue. Special-token strings such as
            ``<|endoftext|>`` are tokenized as ordinary text.
        max_new_tokens: Number of tokens to generate (cast to ``int``).
        temperature: Sampling temperature (cast to ``float``).
        top_k: Top-k cutoff (cast to ``int``); values that are not positive
            after the cast disable the cutoff.

    Returns:
        The decoded token sequence returned by ``model.generate``.

    Raises:
        gr.Error: If the prompt encodes to zero tokens.
    """
    model = load_model()
    enc = state["enc"]

    # encode_ordinary never raises on special-token text typed by the user,
    # unlike encode() with its default disallowed_special setting.
    tokens = enc.encode_ordinary(prompt or "")
    if not tokens:
        # An empty (1, 0) input gives the model nothing to condition on.
        raise gr.Error("Prompt is empty - enter some text to generate from.")

    x = torch.tensor([tokens], dtype=torch.long, device=DEVICE)

    # Cast before comparing so a fractional value below 1 disables top-k
    # instead of becoming top_k=0.
    k = int(top_k)

    # NOTE(review): some minGPT versions default generate() to greedy decoding
    # unless do_sample=True is passed - confirm that temperature/top_k take
    # effect with the bundled mingpt.model.
    y = model.generate(
        x,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=k if k > 0 else None,
    )
    return enc.decode(y[0].tolist())


# Gradio UI
with gr.Blocks(title="🎬 Final Thesis Plot Generator") as demo:
    gr.Markdown("## 🎬 Film Plot Generator\nUnesi prompt i generiraj radnju filma.")

    prompt = gr.Textbox(
        label="Prompt",
        lines=5,
        placeholder="E.g. A young detective arrives in a coastal town...",
    )
    max_new_tokens = gr.Slider(32, 512, value=200, step=16, label="Max new tokens")
    temperature = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
    top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K (0 = off)")

    btn = gr.Button("Generate")
    output = gr.Textbox(label="Output", lines=15)

    btn.click(generate, [prompt, max_new_tokens, temperature, top_k], output)

if __name__ == "__main__":
    demo.launch()