TLH01 committed on
Commit
4507ad7
·
verified ·
1 Parent(s): 7f66618

Update apptest.py

Browse files
Files changed (1) hide show
  1. apptest.py +189 -60
apptest.py CHANGED
@@ -1,84 +1,213 @@
 
 
 
 
 
1
  import streamlit as st
2
  from PIL import Image
3
  import tempfile
4
  import numpy as np
5
- from transformers import pipeline, set_seed
6
- import soundfile as sf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- # --- Model initialization (cache-optimized) ---
 
 
9
  @st.cache_resource
10
- def load_models():
11
- caption_pipeline = pipeline(
 
12
  "image-to-text",
13
  model="Salesforce/blip-image-captioning-base",
14
  device="cuda" if torch.cuda.is_available() else "cpu"
15
  )
16
- story_pipeline = pipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  "text-generation",
18
  model="pranavpsv/gpt2-genre-story-generator",
19
  device="cuda" if torch.cuda.is_available() else "cpu"
20
  )
21
- tts_pipeline = pipeline(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  "text-to-speech",
23
- model="speechbrain/tts-tacotron2-ljspeech",
24
  device="cuda" if torch.cuda.is_available() else "cpu"
25
  )
26
- return caption_pipeline, story_pipeline, tts_pipeline
27
-
28
- # --- Stage 1: Image → Caption ---
29
- def generate_caption(image, pipeline):
30
- caption = pipeline(image)[0]['generated_text']
31
- return caption
32
-
33
- # --- Stage 2: Caption → Story (strict word limit) ---
34
- def generate_story(caption, pipeline):
35
- prompt = f"Generate a children's story in 50-100 words about: {caption}"
36
- story = pipeline(
37
- prompt,
38
- max_length=150, # Token数量(约对应100词)
39
- min_length=80, # 约对应50词
40
- do_sample=True,
41
- temperature=0.7,
42
- top_k=50,
43
- num_return_sequences=1
44
- )[0]['generated_text']
45
- # Remove the repeated prompt and truncate
46
- story = story.replace(prompt, "").strip().split(".")[:5] # 取前5个句子
47
- return ".".join(story[:5]) + "." # 确保以句号结尾
48
 
49
- # --- Stage 3: Story → Audio (Spaces-compatible) ---
50
- def generate_audio(story_text, pipeline):
51
- speech = pipeline(story_text)
52
- audio_array = speech["audio"].squeeze().numpy()
53
- sample_rate = speech["sampling_rate"]
54
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
55
- sf.write(f.name, audio_array, sample_rate)
56
- return f.name
 
 
 
57
 
58
- # --- Streamlit UI ---
 
 
59
  def main():
60
- st.title("📖 AI Storyteller for Kids")
61
- caption_pipeline, story_pipeline, tts_pipeline = load_models()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- uploaded_image = st.file_uploader("Upload a child-friendly image", type=["jpg", "jpeg", "png"])
64
- if uploaded_image:
65
- image = Image.open(uploaded_image)
66
- st.image(image, use_column_width=True)
67
-
68
- with st.spinner("🔍 Analyzing the image..."):
69
- caption = generate_caption(image, caption_pipeline)
70
- st.success(f"📝 Caption: {caption}")
71
-
72
- with st.spinner("✨ Creating a magical story..."):
73
- story = generate_story(caption, story_pipeline)
74
- st.subheader("📚 Your Story")
75
- st.write(story)
76
- st.info(f"Word count: {len(story.split())}") # 显示字数
77
-
78
- with st.spinner("🔊 Generating audio..."):
79
- audio_path = generate_audio(story, tts_pipeline)
80
- st.audio(audio_path, format="audio/wav")
 
81
 
82
  if __name__ == "__main__":
83
- import torch # 延迟导入以避免Spaces预加载问题
 
84
  main()
 
1
+ """
2
+ Magic Story Generator App for Hugging Face Spaces
3
+ Creates custom children's stories from uploaded images
4
+ """
5
+
6
  import streamlit as st
7
  from PIL import Image
8
  import tempfile
9
  import numpy as np
10
+ from transformers import pipeline
11
+ import torch
12
+ import os
13
+
14
+ # ======================
15
+ # UI Configuration
16
+ # ======================
17
def configure_ui():
    """Apply the page settings and inject the child-friendly stylesheet.

    Calls ``st.set_page_config`` with the app title, icon and wide layout,
    then emits a single ``<style>`` block through ``st.markdown`` with raw
    HTML enabled.
    """
    page_settings = {
        "page_title": "✨ Magic Story Generator",
        "page_icon": "🧚",
        "layout": "wide",
    }
    st.set_page_config(**page_settings)

    # Custom CSS for the child-friendly design (page background, heading,
    # buttons, upload drop zone and the story container).
    stylesheet = """
    <style>
    .main {
        background-color: #FFF5E6;
        background-image: url('https://img.freepik.com/free-vector/hand-drawn-childish-pattern_23-2149073136.jpg');
        background-size: 30%;
        opacity: 0.9;
    }
    h1 {
        color: #FF6B6B;
        font-family: 'Comic Sans MS', cursive;
        text-align: center;
        text-shadow: 2px 2px 4px #FFD166;
    }
    .stButton>button {
        background-color: #4ECDC4;
        color: white;
        border-radius: 20px;
        padding: 10px 24px;
        font-weight: bold;
    }
    .stFileUploader>div>div>div>div {
        border: 2px dashed #FF9E7D;
        border-radius: 15px;
        background-color: #FFF0F5;
    }
    .story-box {
        background-color: #FFF0F5;
        padding: 20px;
        border-radius: 15px;
        border-left: 5px solid #FF6B6B;
        font-family: 'Comic Sans MS', cursive;
        font-size: 18px;
        line-height: 1.6;
    }
    </style>
    """
    st.markdown(stylesheet, unsafe_allow_html=True)
63
 
64
# ======================
# Stage 1: Image Captioning
# ======================
@st.cache_resource
def load_image_captioner():
    """Build the BLIP image-captioning pipeline.

    Uses the ``Salesforce/blip-image-captioning-base`` checkpoint and places
    it on CUDA when ``torch.cuda.is_available()`` reports a GPU, otherwise on
    the CPU. The function is wrapped in ``st.cache_resource``.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=target_device,
    )
    return captioner
75
+
76
def generate_caption(_pipeline, image):
    """Return a text description of ``image`` produced by ``_pipeline``.

    Args:
        _pipeline: Callable invoked as ``_pipeline(image, max_new_tokens=50)``
            that returns a sequence whose first item is a mapping with a
            ``"generated_text"`` entry.
        image: The image object handed to ``_pipeline`` unchanged.

    Returns:
        The generated caption with surrounding whitespace removed, or
        ``None`` when captioning fails (the error is shown via ``st.error``).
    """
    try:
        # max_new_tokens is an upper bound on generated tokens; it does not
        # enforce any minimum caption length.
        result = _pipeline(image, max_new_tokens=50)
        # Trim the caption: it is later interpolated into the story prompt.
        return result[0]["generated_text"].strip()
    except Exception as e:
        # Deliberate UI boundary: report the problem instead of crashing.
        st.error(f"Caption generation failed: {str(e)}")
        return None
85
+
86
+ # ======================
87
+ # Stage 2: Story Generation
88
+ # ======================
89
@st.cache_resource
def load_story_generator():
    """Build the GPT-2 story text-generation pipeline.

    Uses the ``pranavpsv/gpt2-genre-story-generator`` checkpoint, placed on
    CUDA when ``torch.cuda.is_available()`` reports a GPU and on the CPU
    otherwise. The function is wrapped in ``st.cache_resource``.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = pipeline(
        "text-generation",
        model="pranavpsv/gpt2-genre-story-generator",
        device=target_device,
    )
    return generator
97
+
98
def generate_story(_pipeline, keywords):
    """Create a short children's story from an image caption.

    Args:
        _pipeline: Text-generation callable invoked as
            ``_pipeline(prompt, **generation_kwargs)`` that returns a sequence
            whose first item is a mapping with a ``"generated_text"`` entry.
        keywords: Caption text the story should be based on.

    Returns:
        The generated story with the prompt and surrounding whitespace
        removed, or ``None`` when generation fails (the error is shown via
        ``st.error``).
    """
    prompt = (
        f"Generate a children's story (60-100 words) based on: {keywords}\n"
        "Requirements:\n"
        "- Use simple English (Grade 2 level)\n"
        "- Include magical/fantasy elements\n"
        "- Have positive moral lesson\n"
        "- Happy ending\n"
        "- Exactly 3 paragraphs\n"
        "\n"
        "Story:"
    )

    try:
        story = _pipeline(
            prompt,
            # Budget only the NEW tokens (~100 words). max_length would also
            # count the prompt, so longer captions would shorten the story.
            max_new_tokens=150,
            # Ask for the continuation only instead of prompt + continuation.
            return_full_text=False,
            temperature=0.7,  # Balance creativity vs coherence
            do_sample=True,
            top_k=50,
        )[0]["generated_text"]
        # Fallback for callables that still echo the prompt: drop it only
        # when it is the leading text, never from the middle of the story.
        if story.startswith(prompt):
            story = story[len(prompt):]
        return story.strip()
    except Exception as e:
        # Deliberate UI boundary: report the problem instead of crashing.
        st.error(f"Story generation failed: {str(e)}")
        return None
123
+
124
+ # ======================
125
+ # Stage 3: Text-to-Speech
126
+ # ======================
127
@st.cache_resource
def load_tts():
    """Build the text-to-speech pipeline used to narrate stories.

    Loads the ``facebook/mms-tts-eng`` checkpoint (the English voice, going
    by its model id) on CUDA when ``torch.cuda.is_available()`` reports a
    GPU, otherwise on the CPU. The function is wrapped in
    ``st.cache_resource``.
    """
    return pipeline(
        "text-to-speech",
        model="facebook/mms-tts-eng",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
def text_to_speech(_pipeline, text):
    """Synthesize ``text`` to speech and save it as a temporary WAV file.

    Args:
        _pipeline: Text-to-speech callable invoked as ``_pipeline(text)`` that
            returns a mapping with ``"audio"`` (the waveform) and
            ``"sampling_rate"`` entries.
        text: The story text to narrate.

    Returns:
        Path of the written ``.wav`` file, or ``None`` when synthesis or
        writing fails (the error is shown via ``st.error``). The file is
        created with ``delete=False`` and is not removed by this function.
    """
    wav_path = None
    try:
        import soundfile as sf

        audio = _pipeline(text)
        waveform = audio["audio"]
        # The pipeline normally hands back a NumPy array, which has no
        # ``.numpy()`` method; only convert when a tensor-like object arrives.
        if hasattr(waveform, "detach"):
            waveform = waveform.detach().cpu().numpy()
        waveform = np.asarray(waveform).squeeze()

        # Reserve a file name, then close the handle before writing so the
        # writer is not competing with an open descriptor for the same file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            wav_path = f.name
        sf.write(wav_path, waveform, audio["sampling_rate"])
        return wav_path
    except Exception as e:
        # Do not leave an empty or partial temp file behind on failure.
        if wav_path is not None and os.path.exists(wav_path):
            os.remove(wav_path)
        st.error(f"Audio generation failed: {str(e)}")
        return None
147
 
148
# ======================
# Main Application
# ======================
def main():
    """Run the Streamlit page: upload → caption → story → audio.

    Each stage stops the run early when the previous stage produced nothing,
    so later stages never see a missing caption or story.
    """
    import html  # stdlib; imported here so this function is self-contained

    # Configure UI first
    configure_ui()

    # App header
    st.title("🧚 Magic Story Generator")
    st.markdown("""
    <div style="text-align:center; color:#FF8E72; font-family: 'Comic Sans MS'; font-size: 20px;">
    Upload a child's photo and AI will create a custom fairy tale with audio!
    </div>
    """, unsafe_allow_html=True)

    # File upload section
    uploaded_file = st.file_uploader(
        "Choose a photo of children's activity",
        type=["jpg", "jpeg", "png"],
        help="Examples: playing, reading, drawing etc.",
    )

    if not uploaded_file:
        st.info("👆 Please upload an image to begin")
        return

    # Display uploaded image; a corrupt or mislabeled upload is reported
    # instead of crashing the page.
    try:
        image = Image.open(uploaded_file)
    except OSError as e:
        st.error(f"Could not read the uploaded image: {e}")
        return
    st.image(image, caption="Your uploaded photo", use_column_width=True)

    # Load all models (shows loading animation)
    with st.spinner("🪄 Preparing magic tools..."):
        caption_pipe = load_image_captioner()
        story_pipe = load_story_generator()
        tts_pipe = load_tts()

    # --- Stage 1: Image Captioning ---
    with st.spinner("🔍 Analyzing the image..."):
        caption = generate_caption(caption_pipe, image)
    if not caption:
        return
    st.success(f"📝 AI sees: {caption}")

    # --- Stage 2: Story Generation ---
    with st.spinner("✍️ Writing your story..."):
        story = generate_story(story_pipe, caption)
    if not story:
        return
    st.subheader("📖 Your Custom Story")
    # The story is model output rendered as raw HTML: escape it so stray
    # "<" / "&" cannot inject or break markup, and turn newlines into <br>
    # so blank lines do not terminate the surrounding HTML block.
    story_html = html.escape(story).replace("\n", "<br>")
    st.markdown(
        f'<div class="story-box">{story_html}</div>',
        unsafe_allow_html=True,
    )

    # --- Stage 3: Text-to-Speech ---
    with st.spinner("🔊 Creating audio version..."):
        audio_path = text_to_speech(tts_pipe, story)
    if not audio_path:
        return
    st.audio(audio_path, format="audio/wav")
    st.success("Audio ready! Click play above to listen")
    st.balloons()  # Celebration animation


if __name__ == "__main__":
    # Set Hugging Face cache location
    # NOTE(review): transformers is imported at the top of the file, before
    # this assignment runs — confirm the variable is still honored when set
    # this late, or move the assignment above the imports.
    os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"
    main()