Spaces:
Sleeping
Sleeping
File size: 6,202 Bytes
4e89892 cf85a3a 4e89892 ab4dfe6 cf85a3a ab4dfe6 4e89892 cf85a3a ab4dfe6 cf85a3a 4e89892 cf85a3a ab4dfe6 4e89892 cf85a3a 4e89892 ab4dfe6 4e89892 cf85a3a 0b37765 ab4dfe6 0b37765 4e89892 ab4dfe6 0b37765 ab4dfe6 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 0b37765 4e89892 cf85a3a 0b37765 cf85a3a 4e89892 cf85a3a 4e89892 cf85a3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import re
import warnings

import gradio as gr
import torch
from transformers import PreTrainedTokenizerFast, T5ForConditionalGeneration
# KoT5 summarization model (downloaded from the Hugging Face Hub at import time).
MODEL_NAME = "psyche/KoT5-summarization"
tokenizer = PreTrainedTokenizerFast.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Apply dynamic int8 quantization to the nn.Linear layers (intended for CPU
# inference). This is best-effort: on failure we keep the float model, but we
# report it instead of silently swallowing the error. The previous bare
# ``except:`` also swallowed KeyboardInterrupt/SystemExit.
try:
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
except Exception as exc:  # any failure here just means "run unquantized"
    warnings.warn(f"Dynamic quantization skipped; using the float model ({exc!r})")
model.eval()
# ===== Utilities =====
def normalize_text(text: str) -> str:
    """Collapse every run of whitespace to a single space and trim both ends.

    ``str.split()`` with no argument splits on the same Unicode whitespace
    set as the regex ``\\s`` and discards empty fields, so joining with one
    space yields the normalized string.
    """
    return " ".join(text.split())
def split_into_sentences(text: str):
    """Split ``text`` into sentences.

    Newlines are first turned into spaces. The text is then cut at every
    whitespace run that directly follows '.', '!' or '?'. The punctuation stays
    attached to its sentence. Pieces are stripped and empty pieces are dropped.

    Returns:
        list[str]: the non-empty sentences in their original order.
    """
    flattened = text.replace("\n", " ")
    sentences = []
    for piece in re.split(r"(?<=[\.!?])\s+", flattened):
        piece = piece.strip()
        if piece:
            sentences.append(piece)
    return sentences
def token_length(s: str) -> int:
    """Return the number of tokens the module-level ``tokenizer`` produces for ``s``.

    Special tokens (e.g. EOS) are excluded from the count.
    """
    ids = tokenizer.encode(s, add_special_tokens=False)
    return len(ids)
def chunk_by_tokens(sentences, max_tokens=900, length_fn=None):
    """Greedily pack sentences into chunks of at most ``max_tokens`` tokens.

    Sentences keep their order and are joined with single spaces. A sentence
    that alone exceeds ``max_tokens`` is cut into fixed-size character slices.
    Each slice is at least 200 characters, scaled by ``max_tokens / tokens``,
    and becomes its own chunk. Slice sizing is a character heuristic, so a
    slice may still exceed ``max_tokens`` slightly.

    Args:
        sentences: Iterable of sentence strings.
        max_tokens: Token budget per chunk.
        length_fn: Optional callable returning the token count of a string.
            Defaults to ``token_length`` (the module tokenizer).

    Returns:
        list[str]: chunk strings in input order; no input text is dropped.
    """
    measure = token_length if length_fn is None else length_fn
    chunks, cur, cur_tokens = [], [], 0
    for s in sentences:
        tl = measure(s)
        if tl > max_tokens:
            # Flush the pending chunk first. Previously it was reset without
            # being emitted, silently losing every sentence before this one.
            if cur:
                chunks.append(" ".join(cur))
            cur, cur_tokens = [], 0
            piece_size = max(200, int(len(s) * (max_tokens / tl)))
            for i in range(0, len(s), piece_size):
                sub = s[i:i + piece_size].strip()
                if sub:
                    chunks.append(sub)
            continue
        if cur_tokens + tl <= max_tokens:
            cur.append(s)
            cur_tokens += tl
        else:
            if cur:
                chunks.append(" ".join(cur))
            cur, cur_tokens = [s], tl
    if cur:
        chunks.append(" ".join(cur))
    return chunks
# ===== Repetition removal =====
# A run of 3+ identical characters ("ㅋㅋㅋㅋ", "!!!!") is shortened to 2.
# Digits are excluded so numbers such as "1000" or "2000년" survive intact.
# Newlines are excluded, matching the original "." pattern. This rule also
# covers the old punctuation-only rule, which could never fire after it.
_CHAR_RUN_RE = re.compile(r"([^\d\n])\1{2,}")
# A whole word immediately repeated (whitespace-separated) collapses to one
# copy. The trailing \b stops "the theory" from collapsing to "theory".
_WORD_RUN_RE = re.compile(r"(\b\w+\b)(\s+\1\b)+")


def derpeat(text: str) -> str:
    """Remove degenerate repetition from generated text.

    Shortens runs of 3+ identical non-digit characters to 2. Collapses
    whole words repeated back-to-back into one. Strips the ends.

    Args:
        text: Text to clean.

    Returns:
        str: the cleaned, stripped text.
    """
    text = _CHAR_RUN_RE.sub(r"\1\1", text)
    text = _WORD_RUN_RE.sub(r"\1", text)
    return text.strip()
# ===== Summarization =====
def approx_tokens_from_chars(n_chars: int) -> int:
    """Estimate a token count from a character count.

    Uses the rough heuristic for Korean text of one token per two
    characters (floor division). The estimate is never below 1.
    """
    halved = n_chars // 2
    return halved if halved > 1 else 1
def _max_new_tokens(input_text: str, target_chars: int, input_tokens: int) -> int:
    """Return the generation budget (new tokens) for one summarization call."""
    # Never target more than ~90% of the input length (floor of 120 chars).
    safe_target_chars = min(target_chars, max(120, int(len(input_text) * 0.9)))
    # Convert the character target to tokens, clamped to [40, 300].
    budget = max(40, min(approx_tokens_from_chars(safe_target_chars), 300))
    # Short inputs: at most 60% of the input's tokens (floor 40). This also
    # covers inputs of <= 60 tokens. The old extra "cap at 60" branch was a
    # no-op because the budget here is already <= 40.
    if input_tokens <= 200:
        budget = min(budget, max(40, int(input_tokens * 0.6)))
    return budget


def summarize_raw_t5(input_text: str, target_chars: int, input_tokens: int) -> str:
    """Run one sampled T5 generation over ``input_text`` and return the decoded text.

    Args:
        input_text: Full model input (style prompt plus text).
        target_chars: Desired summary length in characters (a soft target).
        input_tokens: Token count of the text being summarized. Used to
            shrink the budget for short inputs.

    Returns:
        str: the decoded summary, with special tokens removed.

    Notes:
        Input beyond 1024 tokens is truncated silently. Decoding samples
        (``top_p=0.92``, ``temperature=0.7``), so repeated calls can differ.
        ``encoder_no_repeat_ngram_size=4`` forbids any 4-gram that occurs in
        the input from appearing in the output. ``early_stopping`` was removed
        because it only affects beam search and is ignored with
        ``num_beams=1``.
    """
    max_new = _max_new_tokens(input_text, target_chars, input_tokens)
    input_ids = tokenizer.encode(
        input_text, return_tensors="pt", truncation=True, max_length=1024
    )
    with torch.no_grad():
        summary_ids = model.generate(
            input_ids,
            max_new_tokens=max_new,
            do_sample=True,
            top_p=0.92,
            temperature=0.7,
            num_beams=1,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=4,
            repetition_penalty=1.2,
            renormalize_logits=True,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# Prompt prefix per summary style ("concise summary:", "explanatory summary:",
# "bullet summary:"). Unknown modes fall back to the bullet tag.
_STYLE_TAGS = {
    "concise": "간결 요약:",
    "explanatory": "설명 요약:",
    "bullets": "불릿 요약:",
}
# Extra instruction for the final merge pass:
# "(Keep the original document's order and remove duplicates.)"
_FINAL_GUIDE = " (원래 문서의 순서를 유지하고 중복을 제거하세요.)"


def apply_style_prompt_t5(text: str, mode: str, final: bool = False) -> str:
    """Prefix ``text`` with the Korean style tag for ``mode``.

    Args:
        text: Text to summarize.
        mode: "concise", "explanatory" or "bullets". Any other value is
            treated as "bullets".
        final: When True, append the keep-order / remove-duplicates guide,
            used for the final merge of partial summaries.

    Returns:
        str: ``"<tag><guide>\\n<text>"``.
    """
    tag = _STYLE_TAGS.get(mode, _STYLE_TAGS["bullets"])
    guide = _FINAL_GUIDE if final else ""
    return f"{tag}{guide}\n{text}"
def postprocess_strict(summary: str, mode: str) -> str:
    """Turn raw model output into the final summary text.

    Steps: collapse whitespace, remove repetition via ``derpeat``, split into
    sentences, and drop exact duplicate sentences (first occurrence wins).
    For ``mode == "bullets"`` the first 12 sentences are rendered as "- "
    lines joined by newlines. Otherwise the sentences are joined with single
    spaces.
    """
    cleaned = derpeat(re.sub(r"\s+", " ", summary.strip()))
    pieces = (piece.strip() for piece in re.split(r"(?<=[\.!?])\s+", cleaned))
    # dict.fromkeys keeps first-seen order while discarding repeats.
    unique = list(dict.fromkeys(piece for piece in pieces if piece))
    if mode == "bullets":
        return "\n".join("- " + line for line in unique[:12])
    return " ".join(unique)
# Inputs (and merged partial summaries) up to this many tokens are summarized
# in a single model call.
_SINGLE_PASS_TOKENS = 1000
# Upper bound on extra map passes used to shrink an oversized merge.
_MAX_REDUCE_PASSES = 3


def _summarize_chunks(text: str, target_chars: int, mode: str) -> str:
    """Summarize ``text`` chunk by chunk and return the merged, de-repeated partials."""
    chunks = chunk_by_tokens(split_into_sentences(text), max_tokens=900)
    # Spread ~120% of the target length across chunks, at least 180 chars each.
    per_chunk_chars = max(180, int(target_chars * 1.2 / max(1, len(chunks))))
    partial_summaries = [
        summarize_raw_t5(
            apply_style_prompt_t5(c, mode, final=False),
            per_chunk_chars,
            token_length(c),
        )
        for c in chunks
    ]
    return derpeat(normalize_text(" ".join(partial_summaries)))


def summarize_long(text: str, target_chars: int, mode: str):
    """Summarize text of any length, chunking when it exceeds one model pass.

    Args:
        text: Raw input text (``None`` is treated as empty).
        target_chars: Desired summary length in characters.
        mode: "concise", "explanatory" or "bullets".

    Returns:
        str: the post-processed summary. For empty input, a Korean prompt
        asking the user to enter text.

    Long inputs are summarized per chunk and then merged. If the merged
    partial summaries are still too long for one pass, they are summarized
    again, up to ``_MAX_REDUCE_PASSES`` times. Previously the merge went
    straight into ``summarize_raw_t5``, whose 1024-token truncation silently
    dropped the end of long documents.
    """
    text = normalize_text(text or "")
    if not text:
        return "⚠️ 요약할 텍스트를 입력하세요."
    approx_tokens = token_length(text)
    if approx_tokens <= _SINGLE_PASS_TOKENS:
        # Very short inputs (<= 60 tokens) are capped at 300 characters.
        budget = min(target_chars, 300) if approx_tokens <= 60 else target_chars
        prompt = apply_style_prompt_t5(text, mode, final=False)
        return postprocess_strict(summarize_raw_t5(prompt, budget, approx_tokens), mode)

    merged = _summarize_chunks(text, target_chars, mode)
    merged_tokens = token_length(merged)
    passes = 0
    while merged_tokens > _SINGLE_PASS_TOKENS and passes < _MAX_REDUCE_PASSES:
        merged = _summarize_chunks(merged, target_chars, mode)
        merged_tokens = token_length(merged)
        passes += 1

    final_prompt = apply_style_prompt_t5(merged, mode, final=True)
    final = summarize_raw_t5(final_prompt, target_chars, merged_tokens)
    return postprocess_strict(final, mode)
# ===== Gradio UI =====
# Radio-button label -> internal summary mode
# ("concise type", "explanatory type", "key-point bullets").
STYLE_TO_MODE = {
    "간결형": "concise",
    "설명형": "explanatory",
    "핵심 bullet": "bullets",
}


def ui_summarize(text, target_len, style):
    """Gradio callback: map the UI inputs onto ``summarize_long``.

    Args:
        text: Contents of the input textbox.
        target_len: Slider value (target summary length in characters),
            converted to int.
        style: Selected radio label; must be a key of ``STYLE_TO_MODE``
            (raises KeyError otherwise).

    Returns:
        str: the summary text to display.
    """
    mode = STYLE_TO_MODE[style]
    return summarize_long(text, int(target_len), mode)
# Two-column layout: inputs (text, style, target length, run button) on the
# left and the summary output on the right.
with gr.Blocks() as demo:
    # NOTE(review): only the first byte of the heading emoji survived the
    # mis-decoding of this file; "📝" is a best guess — confirm.
    # Heading: "KoT5 Korean summarizer (repetition suppression + order preservation)".
    gr.Markdown("## 📝 KoT5 한국어 요약기 (반복 억제 + 순서 보존)")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="원문 입력", lines=16)  # "Source text"
            style = gr.Radio(
                ["간결형", "설명형", "핵심 bullet"],
                value="간결형",
                label="요약 스타일",  # "Summary style"
            )
            target_len = gr.Slider(
                300, 1500, value=1000, step=50,
                label="목표 요약 길이(문자)",  # "Target summary length (chars)"
            )
            btn = gr.Button("요약 실행")  # "Summarize"
        with gr.Column():
            output_text = gr.Textbox(label="요약 결과", lines=16)  # "Summary"
    # Event listeners are registered inside the Blocks context.
    btn.click(ui_summarize, inputs=[input_text, target_len, style], outputs=output_text)

if __name__ == "__main__":
    demo.launch()