# NOTE(review): the three lines below were scrape residue from the Hugging Face
# Spaces page header ("Spaces: / Sleeping / Sleeping") and are not program code.
import gc, json, torch, gradio as gr
from huggingface_hub import hf_hub_download
import tiktoken
from mingpt.model import GPT

# Run inference on GPU when one is available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub repository that holds config.json and model.pt.
REPO_ID = "to0ony/final-thesis-plotgen"
# Module-level cache: the lazily-loaded model (filled by load_model) and a
# GPT-2 BPE tokenizer, shared by all requests.
state = {"model": None, "enc": tiktoken.get_encoding("gpt2")}
def load_model():
    """Lazy-load the GPT model from the Hugging Face Hub repository.

    On the first call this downloads ``config.json`` and ``model.pt`` from
    ``REPO_ID``, builds a minGPT model from the config, loads the weights,
    and caches the ready model in the module-level ``state`` dict so every
    later call returns instantly.

    Returns:
        GPT: the model in eval mode, moved to ``DEVICE``.
    """
    if state["model"] is not None:
        return state["model"]

    # Download config and checkpoint (huggingface_hub caches these on disk).
    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    mdl_path = hf_hub_download(repo_id=REPO_ID, filename="model.pt")

    # Load the architecture hyperparameters.
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    gcfg = GPT.get_default_config()
    # NOTE(review): upstream minGPT's default config sets model_type, which
    # conflicts with giving n_layer/n_head/n_embd explicitly — confirm this
    # fork accepts both (or set gcfg.model_type = None if it asserts).
    gcfg.vocab_size = cfg["vocab_size"]
    gcfg.block_size = cfg["block_size"]
    gcfg.n_layer = cfg["n_layer"]
    gcfg.n_head = cfg["n_head"]
    gcfg.n_embd = cfg["n_embd"]

    model = GPT(gcfg)
    # weights_only=True restricts unpickling to tensors/containers, closing
    # the arbitrary-code-execution hole of loading a remote pickle checkpoint.
    sd = torch.load(mdl_path, map_location="cpu", weights_only=True)
    model.load_state_dict(sd, strict=True)
    model.to(DEVICE)
    model.eval()

    state["model"] = model
    return model
def generate(prompt, max_new_tokens=200, temperature=0.9, top_k=50):
    """Generate a text continuation of *prompt* with the cached model.

    Args:
        prompt: seed text; an empty prompt is seeded with the GPT-2
            end-of-text token so generation can still start.
        max_new_tokens: number of tokens to sample (coerced to int).
        temperature: softmax temperature (coerced to float).
        top_k: top-k sampling cutoff; 0 or less disables top-k.

    Returns:
        str: the decoded sequence, prompt included.
    """
    model = load_model()
    enc = state["enc"]

    ids = enc.encode(prompt)
    if not ids:
        # model.generate cannot extend a zero-length sequence; use the
        # GPT-2 end-of-text token as a neutral starting point.
        ids = [enc.eot_token]

    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        y = model.generate(
            x,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_k=int(top_k) if top_k > 0 else None,
        )
    return enc.decode(y[0].tolist())
# --- Gradio UI: prompt + sampling controls feeding the generator ---
with gr.Blocks(title="🎬 Final Thesis Plot Generator") as demo:
    gr.Markdown("## 🎬 Film Plot Generator\nUnesi prompt i generiraj radnju filma.")
    prompt_box = gr.Textbox(
        label="Prompt",
        lines=5,
        placeholder="E.g. A young detective arrives in a coastal town...",
    )
    tokens_slider = gr.Slider(32, 512, value=200, step=16, label="Max new tokens")
    temp_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
    topk_slider = gr.Slider(0, 100, value=50, step=5, label="Top-K (0 = off)")
    generate_btn = gr.Button("Generate")
    result_box = gr.Textbox(label="Output", lines=15)
    # Wire the button: slider values are passed positionally into generate().
    generate_btn.click(
        fn=generate,
        inputs=[prompt_box, tokens_slider, temp_slider, topk_slider],
        outputs=result_box,
    )

# Launch only when executed directly (Spaces runs this file as a script).
if __name__ == "__main__":
    demo.launch()