AlekseyCalvin committed on
Commit
62e3558
·
verified ·
1 Parent(s): e455bfb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -0
app.py CHANGED
@@ -450,6 +450,113 @@ def task_merge(hf_token, base_repo, base_subfolder, lora_input, scale, precision
450
  cleanup_temp()
451
  return f"Done! Merged {buffer.shard_count} shards to {output_repo}"
452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  # =================================================================================
454
  # TAB 4: RESIZE (CPU Optimized)
455
  # =================================================================================
 
450
  cleanup_temp()
451
  return f"Done! Merged {buffer.shard_count} shards to {output_repo}"
452
 
453
+ # =================================================================================
454
+ # TAB 2: EXTRACT LORA
455
+ # =================================================================================
456
+
457
def extract_lora_layer_by_layer(model_org, model_tuned, rank, clamp):
    """Extract a LoRA adapter from the weight delta between two checkpoints.

    For every weight tensor present in both files, the difference
    (tuned - org) is factorised with a truncated SVD into a low-rank
    lora_up / lora_down pair; the singular values are folded into the
    up matrix, and both factors are clipped at the ``clamp`` quantile
    of their combined value distribution to suppress outliers.

    Args:
        model_org:   path to the base-model .safetensors file.
        model_tuned: path to the fine-tuned .safetensors file.
        rank:        requested LoRA rank (capped at min(rank, in_dim, out_dim)).
        clamp:       quantile in (0, 1] used to clip U/Vh values.

    Returns:
        str path of the extracted .safetensors file written under TempDir.
    """
    org = MemoryEfficientSafeOpen(model_org)
    tuned = MemoryEfficientSafeOpen(model_tuned)
    # Hoisted: avoid re-scanning the tuned key list on every iteration.
    tuned_keys = set(tuned.keys())
    lora_sd = {}
    print("Calculating diffs...")
    for key in tqdm(org.keys()):
        if key not in tuned_keys:
            continue
        mat_org = org.get_tensor(key).float()
        mat_tuned = tuned.get_tensor(key).float()
        diff = mat_tuned - mat_org
        # Skip layers that were effectively untouched by fine-tuning.
        if torch.max(torch.abs(diff)) < 1e-4:
            continue

        out_dim, in_dim = diff.shape[:2]
        r = min(rank, in_dim, out_dim)
        is_conv = len(diff.shape) == 4
        if is_conv:
            # Flatten conv kernels to 2D so the SVD applies uniformly.
            diff = diff.flatten(start_dim=1)

        try:
            U, S, Vh = torch.linalg.svd(diff, full_matrices=False)
            U, S, Vh = U[:, :r], S[:r], Vh[:r, :]
            U = U @ torch.diag(S)  # fold singular values into the up matrix
            dist = torch.cat([U.flatten(), Vh.flatten()])
            hi_val = torch.quantile(dist, clamp)
            U = U.clamp(-hi_val, hi_val)
            Vh = Vh.clamp(-hi_val, hi_val)
            if is_conv:
                U = U.reshape(out_dim, r, 1, 1)
                Vh = Vh.reshape(r, in_dim, mat_org.shape[2], mat_org.shape[3])
            else:
                U = U.reshape(out_dim, r)
                Vh = Vh.reshape(r, in_dim)
            stem = key.replace(".weight", "")
            lora_sd[f"{stem}.lora_up.weight"] = U
            lora_sd[f"{stem}.lora_down.weight"] = Vh
            lora_sd[f"{stem}.alpha"] = torch.tensor(r).float()
        except Exception as e:
            # Best-effort per layer: skip tensors the SVD/quantile path cannot
            # handle, but log the skip.  The original bare `except: pass` also
            # swallowed KeyboardInterrupt/SystemExit and hid real failures.
            print(f"Skipping {key}: {e}")
    out = TempDir / "extracted.safetensors"
    save_file(lora_sd, out)
    return str(out)
496
+
497
def task_extract(hf_token, org, tun, rank, out):
    """Download two checkpoints, extract their LoRA diff, and push it to `out`.

    Returns "Done" on success, or an "Error: ..." string on any failure.
    """
    cleanup_temp()
    if hf_token:
        login(hf_token.strip())
    try:
        base_path = download_file(org, hf_token, filename="org.safetensors")
        tuned_path = download_file(tun, hf_token, filename="tun.safetensors")
        extracted = extract_lora_layer_by_layer(base_path, tuned_path, int(rank), 0.99)
        api.create_repo(repo_id=out, exist_ok=True, token=hf_token)
        api.upload_file(
            path_or_fileobj=extracted,
            path_in_repo="extracted.safetensors",
            repo_id=out,
            token=hf_token,
        )
    except Exception as e:
        return f"Error: {e}"
    return "Done"
508
+
509
+ # =================================================================================
510
+ # TAB 3: MERGE ADAPTERS (EMA)
511
+ # =================================================================================
512
+
513
def sigma_rel_to_gamma(sigma_rel):
    """Map a relative EMA std ``sigma_rel`` to the profile exponent gamma.

    Solves the cubic g**3 + 7*g**2 + (16 - t)*g + (12 - t) = 0 with
    t = sigma_rel**-2 and returns the largest non-negative real root.
    """
    t = sigma_rel ** -2
    candidates = np.roots([1.0, 7.0, 16.0 - t, 12.0 - t])
    valid = candidates[np.isreal(candidates) & (candidates.real >= 0)]
    return valid.real.max()
519
+
520
def task_merge_adapters(hf_token, lora_urls, beta, sigma_rel, out_repo):
    """Merge several LoRA adapters into one via an EMA-style running average.

    Args:
        hf_token:  optional HF token used for downloads and the upload.
        lora_urls: comma-separated adapter URLs / repo ids.
        beta:      fixed EMA coefficient, used when sigma_rel <= 0.
        sigma_rel: if > 0, derive a per-step beta from the power-function
                   EMA profile (beta_t = (1 - 1/t) ** (gamma + 1)).
        out_repo:  destination repo id for merged_adapters.safetensors.

    Returns:
        "Done" on success, "No models found" if nothing downloaded,
        or an error string.
    """
    cleanup_temp()
    if hf_token:
        login(hf_token.strip())

    urls = [u.strip() for u in lora_urls.split(",") if u.strip()]
    paths = []
    try:
        for url in urls:
            paths.append(download_lora_smart(url, hf_token))
    except Exception as e:
        return f"Download Error: {e}"

    if not paths:
        return "No models found"

    # Accumulate in float32 for numerical stability.
    base_sd = load_file(paths[0], device="cpu")
    for k in base_sd:
        if base_sd[k].dtype.is_floating_point:
            base_sd[k] = base_sd[k].float()

    gamma = sigma_rel_to_gamma(sigma_rel) if sigma_rel > 0 else None

    for i, path in enumerate(paths[1:]):
        print(f"Merging {path}")
        if gamma is not None:
            # The base adapter is EMA step t=1, so paths[1:][i] is step
            # t = i + 2.  BUGFIX: the previous `t = i + 1` gave t=1 and thus
            # current_beta = 0 on the first iteration, silently discarding
            # the first adapter whenever sigma_rel > 0.
            t = i + 2
            current_beta = (1 - 1 / t) ** (gamma + 1)
        else:
            current_beta = beta

        curr = load_file(path, device="cpu")
        for k in base_sd:
            # `alpha` scalars are rank metadata, not weights — never averaged.
            if k in curr and "alpha" not in k:
                base_sd[k] = base_sd[k] * current_beta + curr[k].float() * (1 - current_beta)

    out = TempDir / "merged_adapters.safetensors"
    save_file(base_sd, out)
    api.create_repo(repo_id=out_repo, exist_ok=True, token=hf_token)
    api.upload_file(
        path_or_fileobj=out,
        path_in_repo="merged_adapters.safetensors",
        repo_id=out_repo,
        token=hf_token,
    )
    return "Done"
559
+
560
  # =================================================================================
561
  # TAB 4: RESIZE (CPU Optimized)
562
  # =================================================================================