Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -50,7 +50,7 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
|
|
| 50 |
|
| 51 |
Returns:
|
| 52 |
results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
|
| 53 |
-
prompt_tags_by_cat: Dictionary for prompt-style output (
|
| 54 |
all_artist_tags: All artist tags (with probabilities) regardless of threshold.
|
| 55 |
"""
|
| 56 |
probs = 1 / (1 + np.exp(-refined_logits))
|
|
@@ -59,7 +59,8 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
|
|
| 59 |
category_thresholds = metadata.get("category_thresholds", {})
|
| 60 |
|
| 61 |
results_by_cat = {}
|
| 62 |
-
|
|
|
|
| 63 |
all_artist_tags = []
|
| 64 |
|
| 65 |
for idx, prob in enumerate(probs):
|
|
@@ -77,22 +78,29 @@ def get_tags(refined_logits: np.ndarray, metadata: dict, default_threshold: floa
|
|
| 77 |
def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
|
| 78 |
"""
|
| 79 |
Format the tags for prompt-style output.
|
|
|
|
| 80 |
|
| 81 |
Returns a comma-separated string of escaped tags.
|
| 82 |
"""
|
| 83 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
for cat in prompt_tags_by_cat:
|
| 85 |
prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
|
| 86 |
|
| 87 |
-
artist_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("artist", [])]
|
| 88 |
character_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("character", [])]
|
| 89 |
general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
| 96 |
return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."
|
| 97 |
|
| 98 |
def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
|
|
@@ -145,8 +153,7 @@ with demo:
|
|
| 145 |
"Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
|
| 146 |
)
|
| 147 |
gr.Markdown(
|
| 148 |
-
"*(Note:
|
| 149 |
-
"You can choose a concise prompt-style output or a detailed category-wise breakdown.)*"
|
| 150 |
)
|
| 151 |
with gr.Row():
|
| 152 |
with gr.Column():
|
|
@@ -162,7 +169,7 @@ with demo:
|
|
| 162 |
maximum=1.0,
|
| 163 |
step=0.05,
|
| 164 |
value=DEFAULT_THRESHOLD,
|
| 165 |
-
label="Threshold"
|
| 166 |
)
|
| 167 |
tag_button = gr.Button("🔍 Tag Image")
|
| 168 |
with gr.Column():
|
|
|
|
| 50 |
|
| 51 |
Returns:
|
| 52 |
results_by_cat: Dictionary mapping each category to a list of (tag, probability) above its threshold.
|
| 53 |
+
prompt_tags_by_cat: Dictionary for prompt-style output (character, general).
|
| 54 |
all_artist_tags: All artist tags (with probabilities) regardless of threshold.
|
| 55 |
"""
|
| 56 |
probs = 1 / (1 + np.exp(-refined_logits))
|
|
|
|
| 59 |
category_thresholds = metadata.get("category_thresholds", {})
|
| 60 |
|
| 61 |
results_by_cat = {}
|
| 62 |
+
# For prompt style, only include character and general tags (artists handled separately)
|
| 63 |
+
prompt_tags_by_cat = {"character": [], "general": []}
|
| 64 |
all_artist_tags = []
|
| 65 |
|
| 66 |
for idx, prob in enumerate(probs):
|
|
|
|
| 78 |
def format_prompt_tags(prompt_tags_by_cat: dict, all_artist_tags: list) -> str:
|
| 79 |
"""
|
| 80 |
Format the tags for prompt-style output.
|
| 81 |
+
Only the top artist tag is shown (regardless of threshold), and all character and general tags are shown.
|
| 82 |
|
| 83 |
Returns a comma-separated string of escaped tags.
|
| 84 |
"""
|
| 85 |
+
# Always select the best artist tag from all_artist_tags, regardless of threshold.
|
| 86 |
+
best_artist_tag = None
|
| 87 |
+
if all_artist_tags:
|
| 88 |
+
best_artist = max(all_artist_tags, key=lambda item: item[1])
|
| 89 |
+
best_artist_tag = escape_tag(best_artist[0])
|
| 90 |
+
|
| 91 |
+
# Sort character and general tags by probability (descending)
|
| 92 |
for cat in prompt_tags_by_cat:
|
| 93 |
prompt_tags_by_cat[cat].sort(key=lambda x: x[1], reverse=True)
|
| 94 |
|
|
|
|
| 95 |
character_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("character", [])]
|
| 96 |
general_tags = [escape_tag(tag) for tag, _ in prompt_tags_by_cat.get("general", [])]
|
| 97 |
+
|
| 98 |
+
prompt_tags = []
|
| 99 |
+
if best_artist_tag:
|
| 100 |
+
prompt_tags.append(best_artist_tag)
|
| 101 |
+
prompt_tags.extend(character_tags)
|
| 102 |
+
prompt_tags.extend(general_tags)
|
| 103 |
+
|
| 104 |
return ", ".join(prompt_tags) if prompt_tags else "No tags predicted."
|
| 105 |
|
| 106 |
def format_detailed_output(results_by_cat: dict, all_artist_tags: list) -> str:
|
|
|
|
| 153 |
"Upload an image, adjust the threshold, and click **Tag Image** to see predictions."
|
| 154 |
)
|
| 155 |
gr.Markdown(
|
| 156 |
+
"*(Note: In prompt-style output, only the top artist tag is displayed along with all character and general tags.)*"
|
|
|
|
| 157 |
)
|
| 158 |
with gr.Row():
|
| 159 |
with gr.Column():
|
|
|
|
| 169 |
maximum=1.0,
|
| 170 |
step=0.05,
|
| 171 |
value=DEFAULT_THRESHOLD,
|
| 172 |
+
label="Default Threshold"
|
| 173 |
)
|
| 174 |
tag_button = gr.Button("🔍 Tag Image")
|
| 175 |
with gr.Column():
|