Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -450,6 +450,113 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision
|
|
| 450 |
cleanup_temp()
|
| 451 |
return f"Done! Merged {buffer.shard_count} shards to {output_repo}"
|
| 452 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
# =================================================================================
|
| 454 |
# TAB 4: RESIZE (CPU Optimized)
|
| 455 |
# =================================================================================
|
|
|
|
| 450 |
cleanup_temp()
|
| 451 |
return f"Done! Merged {buffer.shard_count} shards to {output_repo}"
|
| 452 |
|
| 453 |
+
# =================================================================================
|
| 454 |
+
# TAB 2: EXTRACT LORA
|
| 455 |
+
# =================================================================================
|
| 456 |
+
|
| 457 |
+
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """Extract a LoRA adapter from the weight diff between two checkpoints.

    For each weight tensor present in both files, the difference
    (tuned - org) is factorized with a truncated SVD: the singular values are
    folded into ``lora_up`` (U @ diag(S)) while ``lora_down`` holds Vh. Both
    factors are clipped at the ``clamp`` quantile of their combined value
    distribution to suppress outliers.

    Args:
        model_org: path to the base model ``.safetensors`` file.
        model_tuned: path to the fine-tuned model ``.safetensors`` file.
        rank: maximum LoRA rank; capped per layer by the tensor dimensions.
        clamp: quantile in (0, 1) used to clip extreme factor values.

    Returns:
        str path of the saved ``extracted.safetensors`` file inside TempDir.
    """
    org = MemoryEfficientSafeOpen(model_org)
    tuned = MemoryEfficientSafeOpen(model_tuned)
    # Hoisted: build the membership set once instead of re-listing the tuned
    # file's keys on every loop iteration.
    tuned_keys = set(tuned.keys())
    lora_sd = {}
    print("Calculating diffs...")
    for key in tqdm(org.keys()):
        if key not in tuned_keys:
            continue
        mat_org = org.get_tensor(key).float()
        mat_tuned = tuned.get_tensor(key).float()
        diff = mat_tuned - mat_org
        if torch.max(torch.abs(diff)) < 1e-4:
            continue  # layer effectively unchanged; nothing to extract

        # SVD factorization needs at least a matrix. 1-D tensors (biases,
        # norm scales) previously crashed the 2-tuple unpack below — which sat
        # OUTSIDE the try — aborting the whole extraction; skip them instead.
        if diff.dim() < 2:
            continue

        out_dim, in_dim = diff.shape[:2]
        r = min(rank, in_dim, out_dim)
        is_conv = len(diff.shape) == 4
        if is_conv:
            diff = diff.flatten(start_dim=1)  # (out, in*kh*kw) so SVD sees a matrix

        try:
            U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
            U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
            U = U @ torch.diag(S)  # fold singular values into the up matrix
            # Clip outliers at the requested quantile of all factor values.
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp)
            U = U.clamp(-hi_val, hi_val)
            Vh = Vh.clamp(-hi_val, hi_val)
            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)
            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U
            lora_sd[f"{stem}.lora_down.weight"] = Vh
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        except Exception as e:
            # Best-effort per layer, as before — but report the skip instead of
            # a bare `except: pass`, which also swallowed KeyboardInterrupt.
            print(f"Skipping {key}: {e}")
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
|
| 496 |
+
|
| 497 |
+
def task_extract(hf_token, org, tun, rank, out):
    """Download a base and a tuned checkpoint, extract their LoRA diff,
    and upload the result to the Hub repo ``out``.

    Returns "Done" on success, or an "Error: ..." string on failure.
    """
    cleanup_temp()
    if hf_token:
        login(hf_token.strip())
    try:
        base_path = download_file(org, hf_token, filename="org.safetensors")
        tuned_path = download_file(tun, hf_token, filename="tun.safetensors")
        extracted = extract_lora_layer_by_layer(base_path, tuned_path, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(
            path_or_fileobj=extracted,
            path_in_repo="extracted.safetensors",
            repo_id=out,
            token=hf_token,
        )
        return "Done"
    except Exception as e:
        return f"Error: {e}"
|
| 508 |
+
|
| 509 |
+
# =================================================================================
|
| 510 |
+
# TAB 3: MERGE ADAPTERS (EMA)
|
| 511 |
+
# =================================================================================
|
| 512 |
+
|
| 513 |
+
def sigma_rel_to_gamma(sigma_rel):
    """Convert a relative EMA width ``sigma_rel`` to a power-EMA exponent gamma.

    Finds the largest nonnegative real root of the cubic
    ``g**3 + 7*g**2 + (16 - t)*g + (12 - t)`` with ``t = sigma_rel**-2``,
    i.e. the gamma satisfying
    ``sigma_rel**2 == (g + 1) / ((g + 2)**2 * (g + 3))``.

    Requires ``sigma_rel > 0``.
    """
    t = sigma_rel ** -2
    candidates = np.roots([1.0, 7.0, 16.0 - t, 12.0 - t])
    feasible = [c.real for c in candidates if np.isreal(c) and c.real >= 0]
    return max(feasible)
|
| 519 |
+
|
| 520 |
+
def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
    """Download several LoRA adapters and blend them into one state dict.

    The first adapter seeds a running average; each subsequent adapter is
    blended in with either a fixed coefficient ``beta`` (when
    ``sigma_rel <= 0``) or a per-step power-EMA coefficient derived from
    ``sigma_rel`` via ``sigma_rel_to_gamma``. The merged dict is uploaded to
    ``out_repo`` as ``merged_adapters.safetensors``.

    Args:
        hf_token: optional Hugging Face token used for login and upload.
        lora_urls: comma-separated list of adapter URLs / repo ids.
        beta: fixed EMA coefficient, used when ``sigma_rel <= 0``.
        sigma_rel: if > 0, switches to the power-EMA schedule.
        out_repo: destination repo id on the Hub.

    Returns:
        "Done" on success, or an error string describing the failure.
    """
    cleanup_temp()
    if hf_token: login(hf_token.strip())

    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = []
    try:
        for i, url in enumerate(urls):
            paths.append(download_lora_smart(url, hf_token))
    except Exception as e: return f"Download Error: {e}"

    if not paths: return "No models found"

    # Seed the accumulator with the first adapter; floating-point tensors are
    # upcast to float32 so repeated blending doesn't lose precision on
    # fp16/bf16 checkpoints.
    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point: base_sd[k] = base_sd[k].float()

    gamma = None
    if sigma_rel > 0:
        gamma = sigma_rel_to_gamma(sigma_rel)

    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            # Power-EMA profile: beta_t = (1 - 1/t) ** (gamma + 1).
            # NOTE(review): t starts at 1 for the first merged adapter, which
            # makes beta_1 = 0 — the first step fully replaces the seed.
            # Confirm this indexing matches the intended EMA schedule.
            t = i + 1
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta

        curr = load_file(path, device="cpu")
        for k in base_sd:
            # Blend only keys shared with the current adapter; "alpha" scalars
            # are kept from the seed adapter untouched.
            if k in curr and "alpha" not in k:
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)

    out = TempDir / "merged_adapters.safetensors"
    save_file(base_sd, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(path_or_fileobj=out, path_in_repo="merged_adapters.safetensors", repo_id=out_repo, token=hf_token)
    return "Done"
|
| 559 |
+
|
| 560 |
# =================================================================================
|
| 561 |
# TAB 4: RESIZE (CPU Optimized)
|
| 562 |
# =================================================================================
|