# AiCoderv2's picture
# Update app.py
# 8111aa7 verified
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gradio as gr
# Checkpoint to serve: plain "gpt2", chosen because it is smaller and faster.
model_name = "gpt2"

# Fetch the causal-LM weights and the matching tokenizer for that checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Wrap both in a text-generation pipeline; device=-1 targets the CPU.
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device=-1,
)
def generate_data(prompt, amount):
    """Generate ``amount`` text samples from ``prompt`` in a single batch.

    Args:
        prompt: Text handed to the text-generation pipeline. A missing,
            empty, or whitespace-only prompt yields an empty list without
            calling the model.
        amount: Number of samples to request. Coerced with ``int()`` because
            a UI slider may deliver the value as a float.

    Returns:
        A list with the ``generated_text`` of every returned sequence, each
        stripped of leading and trailing whitespace.
    """
    if not prompt or not prompt.strip():
        return []

    responses = generator(
        prompt,
        max_length=50,  # keep short for speed (limit includes prompt tokens)
        num_return_sequences=int(amount),
        # Sampling must be on: greedy decoding (do_sample=False, num_beams=1)
        # is deterministic and is rejected by transformers when more than one
        # return sequence is requested. It also makes temperature/top_k/top_p
        # below take effect instead of being ignored.
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        num_beams=1,  # no beam search; plain sampling
    )
    return [resp["generated_text"].strip() for resp in responses]
def _render_samples(prompt, amount):
    """Run ``generate_data`` and format its samples for the output textbox.

    ``generate_data`` returns a list of strings, while the output component
    is a single text box; joining the samples into numbered paragraphs avoids
    showing the raw Python list representation to the user.

    Args:
        prompt: Prompt text from the prompt textbox.
        amount: Requested number of samples from the slider.

    Returns:
        One string with every sample numbered and separated by a blank line
        (an empty string when no samples were produced).
    """
    samples = generate_data(prompt, amount)
    return "\n\n".join(
        f"[{index}] {sample}" for index, sample in enumerate(samples, start=1)
    )


# Build the UI: a prompt box, a sample-count slider, an output box and a
# button that wires them to the generator.
with gr.Blocks() as demo:
    gr.Markdown("### Faster Data Generator with GPT-2\nDescribe what data you want to generate.")
    prompt_input = gr.Textbox(label="Prompt / Data Type", placeholder="Describe the data you want")
    amount_input = gr.Slider(1, 10, value=3, step=1, label="Number of Data Items")
    output_box = gr.Textbox(label="Generated Data", lines=15)
    generate_btn = gr.Button("Generate")
    generate_btn.click(
        _render_samples,
        inputs=[prompt_input, amount_input],
        outputs=output_box,
    )

# Start the server only when executed as a script, so that importing this
# module does not launch (and block on) the web app.
if __name__ == "__main__":
    demo.launch()