to0ony's picture
added app.py, requirements.txt and mingpt
8fccddb
raw
history blame
2.32 kB
import gc, json, torch, gradio as gr
from huggingface_hub import hf_hub_download
import tiktoken
from mingpt.model import GPT
# NOTE(review): `gc` is imported but never used in this file — candidate for removal.
# Prefer GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub repository that hosts config.json and model.pt.
REPO_ID = "to0ony/final-thesis-plotgen"
# Module-level cache: the model is lazy-loaded on first use (see load_model);
# the GPT-2 BPE tokenizer is created eagerly since it is cheap and local.
state = {"model": None, "enc": tiktoken.get_encoding("gpt2")}
def load_model():
    """Lazily download and initialize the GPT model from the Hugging Face Hub.

    The instance is cached in the module-level ``state`` dict, so repeated
    calls return the same model without re-downloading or re-building it.

    Returns:
        GPT: the model, moved to ``DEVICE`` and switched to eval mode.
    """
    if state["model"] is not None:
        return state["model"]
    # Fetch config and weights (hf_hub_download caches files locally).
    cfg_path = hf_hub_download(repo_id=REPO_ID, filename="config.json")
    mdl_path = hf_hub_download(repo_id=REPO_ID, filename="model.pt")
    # Load the architecture hyperparameters saved alongside the checkpoint.
    with open(cfg_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    gcfg = GPT.get_default_config()
    for key in ("vocab_size", "block_size", "n_layer", "n_head", "n_embd"):
        setattr(gcfg, key, cfg[key])
    model = GPT(gcfg)
    # weights_only=True restricts unpickling to tensors/primitives — safer for
    # a file fetched over the network, and matches the torch>=2.6 default.
    sd = torch.load(mdl_path, map_location="cpu", weights_only=True)
    model.load_state_dict(sd, strict=True)
    model.to(DEVICE)
    model.eval()
    state["model"] = model
    return model
@torch.inference_mode()
def generate(prompt, max_new_tokens=200, temperature=0.9, top_k=50):
    """Generate a text continuation of *prompt* using the cached GPT model."""
    model = load_model()
    tokenizer = state["enc"]
    # Encode the prompt into a (1, seq_len) batch of token ids on DEVICE.
    token_ids = tokenizer.encode(prompt)
    x = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
    # Gradio sliders deliver floats; cast back to the types model.generate expects.
    # A top_k of 0 (or less) disables top-k filtering entirely.
    sampled = model.generate(
        x,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_k=None if top_k <= 0 else int(top_k),
    )
    return tokenizer.decode(sampled[0].tolist())
# Gradio UI: a single-column demo that wires the input widgets into generate().
with gr.Blocks(title="🎬 Final Thesis Plot Generator") as demo:
    gr.Markdown("## 🎬 Film Plot Generator\nUnesi prompt i generiraj radnju filma.")
    prompt_box = gr.Textbox(label="Prompt", lines=5, placeholder="E.g. A young detective arrives in a coastal town...")
    tokens_slider = gr.Slider(32, 512, value=200, step=16, label="Max new tokens")
    temp_slider = gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Temperature")
    topk_slider = gr.Slider(0, 100, value=50, step=5, label="Top-K (0 = off)")
    run_button = gr.Button("Generate")
    output_box = gr.Textbox(label="Output", lines=15)
    # Clicking the button feeds the four inputs to generate() and shows the result.
    run_button.click(
        generate,
        [prompt_box, tokens_slider, temp_slider, topk_slider],
        output_box,
    )

if __name__ == "__main__":
    demo.launch()