# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random
import time
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Union
import numpy as np
import soundfile as sf
import torch
import wandb
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict
from torch import nn
from torch.utils.data import get_worker_info
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers
from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss
from nemo.collections.tts.models import AudioCodecModel
from nemo.collections.tts.modules import transformer_2501
from nemo.collections.tts.modules.aligner import AlignmentEncoder
from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter
from nemo.collections.tts.modules.magpietts_modules import (
CharAwareSubwordEncoder,
EOSDetectionMethod,
LocalTransformerType,
SpecialAudioToken,
cosine_schedule,
)
from nemo.collections.tts.parts.utils.helpers import (
binarize_attention_parallel,
get_mask_from_lengths,
plot_alignment_to_numpy,
)
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging
@dataclass
class InferBatchOutput:
"""Output dataclass for MagpieTTS infer_batch method.
This provides a consistent return type regardless of which optional outputs
are requested.
Attributes:
predicted_audio: Generated audio waveforms. Shape: (B, T_audio).
predicted_audio_lens: Length of each audio in samples. Shape: (B,).
predicted_codes: Generated audio codec tokens. Shape: (B, num_codebooks, T_frames).
predicted_codes_lens: Length of each code sequence in frames. Shape: (B,).
rtf_metrics: Dictionary containing real-time factor and timing metrics.
cross_attention_maps: Optional cross-attention visualization maps.
List of numpy arrays, one per batch item. Only populated if
return_cross_attn_probs=True.
headwise_cross_attention_maps: Optional per-head cross-attention maps.
Only populated if return_cross_attn_probs=True and
compute_all_heads_attn_maps=True.
"""
predicted_audio: torch.Tensor
predicted_audio_lens: torch.Tensor
predicted_codes: torch.Tensor
predicted_codes_lens: torch.Tensor
rtf_metrics: Dict[str, Any]
cross_attention_maps: Optional[List[Any]] = None
headwise_cross_attention_maps: Optional[List[Any]] = None
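# Illustrative usage sketch (not from this file; the exact infer_batch signature is assumed):
#   out = model.infer_batch(batch)
#   audio, audio_lens = out.predicted_audio, out.predicted_audio_lens
#   sf.write("sample0.wav", audio[0, : audio_lens[0]].cpu().numpy(), model.sample_rate)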
def worker_init_fn(worker_id):
# For mp.set_start_method("spawn", force=True)
# The dataset class should be picklable, so we initialize non-picklable objects here
logging.info(f"Worker {worker_id} initializing...")
worker_info = get_worker_info()
dataset = worker_info.dataset # Get the dataset instance in this worker
tokenizer = setup_tokenizers(dataset.tokenizer_config, mode=dataset.dataset_type)
dataset.text_tokenizer = tokenizer
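# Minimal sketch of the intended use of worker_init_fn (assumed setup, not the lhotse dataloader used below):
# it is passed to the dataloader so that each spawned worker rebuilds its own, non-picklable tokenizer.
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=8, num_workers=4, worker_init_fn=worker_init_fn)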
class MagpieTTSModel(ModelPT):
"""
Magpie-TTS base model class, used to train a TTS model that generates audio codes from a transcript and
context audio/text.
Supports multiple model types:
- multi_encoder_context_tts: Transcript and context audio go to different encoders. The transcript encoding feeds
the decoder layers given by cfg.model.transcript_decoder_layers and the context encoding feeds the layers given
by context_decoder_layers. Also supports text context, which gets encoded by the same encoder as context audio.
Only one of context audio or context text is supported.
- decoder_context_tts: Text goes into the encoder; context & target audio go to the decoder. Also supports text
context. Supports a fixed-size context, so context_duration_min and context_duration_max are set to the same
value (5 seconds). Text context, which is usually shorter than the number of codec frames in 5 seconds of audio,
is padded to the max context duration in this model.
- decoder_ce: Same as decoder_context_tts except there is a small neural network between the context tensors and
the decoder input.
"""
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
self.world_size = 1
if trainer is not None:
self.world_size = trainer.num_nodes * trainer.num_devices
# load codec, disable loading of loss modules not needed during inference
codec_model_path = cfg.get('codecmodel_path')
if codec_model_path.startswith('nvidia/'):
codec_model = AudioCodecModel.from_pretrained(codec_model_path)
else:
codec_model_cfg = AudioCodecModel.restore_from(codec_model_path, return_config=True)
if "use_scl_loss" in codec_model_cfg:
codec_model_cfg.use_scl_loss = False
codec_model = AudioCodecModel.restore_from(
codec_model_path, strict=False, override_config_path=codec_model_cfg
)
self.sample_rate = codec_model.sample_rate
self.codec_model_samples_per_frame = codec_model.samples_per_frame
# del codec discriminator to free memory
del codec_model.discriminator
# When using FSQ tokens, the codebook structure can be changed at any time.
# An FSQ definition can be provided in `vector_quantizer` config to train with a codebook structure
# that is different than in the audio codec checkpoint.
vector_quantizer = cfg.get('vector_quantizer')
if vector_quantizer is not None:
vector_quantizer = instantiate(vector_quantizer)
num_audio_codebooks = vector_quantizer.num_codebooks
codebook_size = vector_quantizer.codebook_size
codec_converter = VectorQuantizerIndexConverter(
vector_quantizer_original=codec_model.vector_quantizer,
vector_quantizer_new=vector_quantizer,
)
data_num_audio_codebooks = codec_model.vector_quantizer.num_codebooks
else:
num_audio_codebooks = codec_model.num_codebooks
data_num_audio_codebooks = num_audio_codebooks
codebook_size = codec_model.codebook_size
codec_converter = None
# The dataloader needs to know the number of codebooks that the context codes were stored in
# If no context codes are saved and there is no context audio (the text-context path),
# we create a dummy context code tensor containing only [context_BOS, context_EOS], repeated across
# data_num_audio_codebooks.
self.data_num_audio_codebooks = data_num_audio_codebooks
self.num_audio_codebooks = num_audio_codebooks
self.codebook_size = codebook_size
# Our codebooks start with actual audio codec tokens, followed by special tokens.
# The `forced_*` options are for backward compatibility for models trained with older code.
get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size)
self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS))
self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS))
self.context_audio_bos_id = cfg.get(
'forced_context_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS)
)
self.context_audio_eos_id = cfg.get(
'forced_context_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS)
)
self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN))
self.num_all_tokens_per_codebook = cfg.get(
'forced_num_all_tokens_per_codebook', self.codebook_size + len(SpecialAudioToken)
)
self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False)
# The frame stacking factor controls how many consecutive frames are processed together by the base decoder
# (and then refined into individual frames by the local transformer). A frame stacking factor of 1 means no
# frame stacking. We have a separate embedding table for each of the stacked frames, e.g. for frame stacking
# factor of 3, the entries of codebook 0 appear 3 times in the embedding table.
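# For example (illustrative), with num_audio_codebooks=8 and frame_stacking_factor=3 there are 24 embedding
# tables in total, indexed as `codebook + frame_index * num_audio_codebooks` for frame_index in [0, 2].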
self.frame_stacking_factor = cfg.get('frame_stacking_factor', 1)
assert 'downsample_factor' not in cfg, '`downsample_factor` is deprecated, use `frame_stacking_factor` instead'
# Setup tokenizer
if hasattr(cfg, 'text_tokenizer'):
# For backward compatibility for English-only models
with open_dict(cfg):
cfg.text_tokenizers = {"english_phoneme": cfg.text_tokenizer}
del cfg['text_tokenizer']
self.use_text_conditioning_encoder = cfg.get('use_text_conditioning_encoder', False)
# Using google-t5/t5-small as default text conditioning tokenizer for backward compatibility.
self.text_conditioning_tokenizer_name = cfg.get('text_conditioning_tokenizer_name', None)
self.legacy_text_conditioning = cfg.get('legacy_text_conditioning', False)
if self.legacy_text_conditioning:
if self.text_conditioning_tokenizer_name is None:
self.text_conditioning_tokenizer_name = "google-t5/t5-small"
tokenizer_target = "AutoTokenizer"
if self.text_conditioning_tokenizer_name == "google-t5/t5-small":
tokenizer_target = "T5Tokenizer"
with open_dict(cfg):
cfg.text_tokenizers[self.text_conditioning_tokenizer_name] = {
'_target_': tokenizer_target,
'pretrained_model': self.text_conditioning_tokenizer_name,
}
elif self.text_conditioning_tokenizer_name is None:
# If no text_conditioning_tokenizer_name is specified, use the first one as default
# For text context tokenization
self.text_conditioning_tokenizer_name = list(cfg.text_tokenizers.keys())[0]
# TODO @xueyang: both tokenizers are only used to get some token ids. We
# should remove them to save a small amount of memory, since the dataloader will initialize them
# again after the worker processes are spawned.
self.tokenizer = setup_tokenizers(
all_tokenizers_config=cfg.text_tokenizers,
mode='train',
)
num_tokens_tokenizer = len(self.tokenizer.tokens)
if self.legacy_text_conditioning:
# Text context tokens are not part of the regular transcript embedding table in legacy models
num_tokens_tokenizer -= self.tokenizer.num_tokens_per_tokenizer[self.text_conditioning_tokenizer_name]
num_tokens = num_tokens_tokenizer + 2 # +2 for BOS and EOS
self.bos_id = num_tokens - 2
self.eos_id = num_tokens - 1
self.model_type = cfg.get('model_type', None)
self.pad_context_text_to_max_duration = self.model_type in ['decoder_context_tts', 'decoder_ce']
self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False)
# Below args (text_context_remapping_json, text_context_remapping_prob) are
# for combining multiple context_texts into a single one during training.
# Eg. if we want to treat Emma_neutral and Emma_conversational as one speaker,
# we can create an override dict {'Emma_neutral' : 'Emma', 'Emma_conversational' : 'Emma'}
# This dict is saved in a json file given by cfg.model.text_context_remapping_json
# If we want to preserve both behaviours, i.e. (Emma_neutral, Emma_conversational) and just (Emma),
# we can apply this mapping with a probability during training, as specified by text_context_remapping_prob
self.text_context_remapping = None
text_context_remapping_json = cfg.get('text_context_remapping_json', None)
self.text_context_remapping_prob = cfg.get('text_context_remapping_prob', 0.0)
if text_context_remapping_json is not None:
with open(text_context_remapping_json, 'r') as f:
self.text_context_remapping = json.load(f)
super().__init__(cfg=cfg, trainer=trainer)
if self.legacy_text_conditioning:
tc_tokenizer = self.tokenizer.tokenizers[self.text_conditioning_tokenizer_name]
self.context_text_embedding = nn.Embedding(tc_tokenizer.vocab_size, cfg.embedding_dim)
# This needs to happen after super().__init__()
self._codec_model = codec_model
self._codec_model.freeze() # Lightning does requires_grad = False and self.eval()
self._codec_converter = codec_converter
audio_embeddings = []
for _ in range(self.num_audio_codebooks * self.frame_stacking_factor):
audio_embeddings.append(nn.Embedding(self.num_all_tokens_per_codebook, cfg.embedding_dim))
self.audio_embeddings = nn.ModuleList(audio_embeddings)
if self.use_bpe_char_tokenizer:
# BPE char tokenizer
assert len(self.tokenizer.tokenizers) == 1, "BPE char tokenizer should only be used with one tokenizer"
tokenizer_name = self.tokenizer.tokenizer_names[0]
tokenizer = self.tokenizer.tokenizers[tokenizer_name]
subword_vocab = tokenizer.get_vocab()
# Special tokens are stored as-is in the char_vocab.
# Each special token is mapped to exactly one char id.
special_vocab = {
'<BOS>': self.bos_id,
'<EOS>': self.eos_id,
}
self.cas_encoder = CharAwareSubwordEncoder(
d_embed=cfg.embedding_dim,
llm_tokenizer_vocab=subword_vocab,
subword_padding_idx=self.tokenizer.pad,
special_vocab=special_vocab,
)
else:
# Regular text embedding
self.text_embedding = nn.Embedding(num_tokens, cfg.embedding_dim)
self.encoder = transformer_2501.Transformer(**dict(cfg.encoder))
self.decoder = transformer_2501.Transformer(**dict(cfg.decoder))
self.final_proj = nn.Linear(
cfg.decoder.d_model,
self.num_audio_codebooks * self.num_all_tokens_per_codebook * self.frame_stacking_factor,
)
self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower())
logging.info(f"Local transformer type: {self.local_transformer_type}")
if self.local_transformer_type != LocalTransformerType.NO_LT:
local_transformer_hidden_dim = cfg.get('local_transformer_hidden_dim', 256)
if local_transformer_hidden_dim != cfg.decoder.d_model:
self.local_transformer_in_projection = nn.Linear(cfg.decoder.d_model, local_transformer_hidden_dim)
else:
self.local_transformer_in_projection = nn.Identity()
self.local_transformer = transformer_2501.Transformer(
n_layers=self.cfg.get('local_transformer_n_layers', 2),
d_model=local_transformer_hidden_dim,
d_ffn=local_transformer_hidden_dim * 4,
sa_n_heads=self.cfg.get('local_transformer_n_heads', 1),
kernel_size=1,
is_causal=self.local_transformer_type == LocalTransformerType.AR,
max_length_causal_mask=self.frame_stacking_factor * self.num_audio_codebooks + 2,
use_learnable_pos_emb=True,
)
local_transformer_out_projections = []
for _ in range(self.num_audio_codebooks * self.frame_stacking_factor):
# Have a separate projection layer for each codebook, to distinguish between them
local_transformer_out_projections.append(
nn.Linear(local_transformer_hidden_dim, self.num_all_tokens_per_codebook)
)
self.local_transformer_out_projections = nn.ModuleList(local_transformer_out_projections)
if cfg.get('use_alignment_encoder', False):
self.alignment_encoder = AlignmentEncoder(
n_mel_channels=cfg.embedding_dim,
n_text_channels=cfg.embedding_dim,
dist_type="cosine",
temperature=15.0,
)
if self.model_type == 'multi_encoder_context_tts':
logging.warning(f"The multi_encoder_context_tts model type for {self} is deprecated.")
# Transcript and context audio/text go to different encoders.
# Output of the encoders goes to the decoder through the cross-attention layers
self.transcript_decoder_layers = cfg.get('transcript_decoder_layers', [3, 4, 5, 6, 7, 8])
self.context_decoder_layers = cfg.get(
'context_decoder_layers', [0, 1, 2, 9, 10, 11]
) # For backward compatibility
multi_encoder_mapping = [None for _ in range(self.decoder.n_layers)]
for layer in self.transcript_decoder_layers:
multi_encoder_mapping[layer] = 0 # 0 means text goes to this layer, 1 means context goes to this layer
for layer in self.context_decoder_layers:
multi_encoder_mapping[layer] = 1
self.multi_encoder_mapping = multi_encoder_mapping
self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder))
elif self.model_type == 'decoder_context_tts':
# Context audio/text goes directly to the decoder (before the target audio codes)
self.transcript_decoder_layers = [
idx for idx in range(self.decoder.n_layers)
] # All layers are used for text
elif self.model_type == 'decoder_ce':
# Similar to decoder_context_tts, but we use context encoder
# Decoder gets output from context encoder instead of raw context tokens embeddings
self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder))
self.transcript_decoder_layers = [
idx for idx in range(cfg.decoder.n_layers)
] # All layers are used for text
# Register buffers for baked context embedding (initially None/empty)
# These will be populated when loading a checkpoint with baked embedding
self.register_buffer('baked_context_embedding', None)
self.register_buffer('baked_context_embedding_len', None)
else:
raise ValueError(f"Unsupported model type {self.model_type}")
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0)
self.alignment_encoder_loss_scale = cfg.get('alignment_encoder_loss_scale', 0.0)
if self.alignment_loss_scale > 0.0:
self.alignment_loss = ForwardSumLoss(loss_scale=self.alignment_loss_scale)
if self.alignment_encoder_loss_scale > 0.0:
self.alignment_encoder_loss = ForwardSumLoss(loss_scale=self.alignment_encoder_loss_scale)
# Define cfg parameters into self parameters
self.prior_end_step = self.cfg.prior_end_step
self.prior_scaledown_start_step = self.cfg.prior_scaledown_start_step
self.indefinite_prior_prob = self.cfg.get('indefinite_prior_prob', 0.0)
self.ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers)
self.cfg_unconditional_prob = self.cfg.get('cfg_unconditional_prob', 0.0)
self.decoder_input_dropout_prob = self.cfg.get('decoder_input_dropout_prob', 0.0)
self.binarize_attn_method = self.cfg.get('binarize_attn_method', 'argmax')
self.binarize_repeat_audio_factor = self.cfg.get('binarize_repeat_audio_factor', 2)
self.prior_future_decay = self.cfg.get('prior_future_decay', 1.0)
self.prior_past_decay = self.cfg.get('prior_past_decay', 1.0)
self.binarized_prior_epsilon = self.cfg.get('binarized_prior_epsilon', 0.0)
self.prior_future_context = self.cfg.get('prior_future_context', 1)
self.prior_past_context = self.cfg.get('prior_past_context', 1)
self.binarize_prior_after_step = self.cfg.get('binarize_prior_after_step', 0)
self.codebook_loss_scale = self.cfg.get('codebook_loss_scale', 1.0)
self.local_transformer_loss_scale = self.cfg.get('local_transformer_loss_scale', 1.0)
self.use_alignment_encoder = self.cfg.get('use_alignment_encoder', False)
self.use_prior_for_aligner = self.cfg.get('use_prior_for_aligner', False)
self.aligner_encoder_train_steps = self.cfg.get('aligner_encoder_train_steps', float('inf'))
self.dec_random_input_max = self.cfg.get('dec_random_input_max', self.num_all_tokens_per_codebook)
# Configuration validity checks
self.check_frame_stacking_config_validity()
def state_dict(self, destination=None, prefix='', keep_vars=False):
"""
Only used for saving checkpoints. On save, we remove _speaker_verification_model and _codec_model
from the checkpoint. The codec model is saved in a separate checkpoint.
_speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts
model_type that is no longer supported and can likely be removed in a future version.
If the model has a baked context embedding, the context_encoder weights are also excluded
since they are no longer needed for inference.
"""
if hasattr(self, '_no_state_dict') and self._no_state_dict:
return {}
# Don't save the speaker verification and codec model in the state dict
state_dict = super().state_dict(destination, prefix, keep_vars)
keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model']
# If we have a baked context embedding, exclude context_encoder weights
if self.has_baked_context_embedding:
keys_substrings_to_exclude.append('context_encoder')
for key in list(state_dict.keys()):
if any([substring in key for substring in keys_substrings_to_exclude]):
del state_dict[key]
return state_dict
def check_frame_stacking_config_validity(self):
"""
Check if the configuration is compatible with frame stacking.
"""
if self.frame_stacking_factor > 1:
# The settings below are not supported with frame stacking.
# Some of them may work - but they have not been tested.
# disallow alignment encoder
if self.use_alignment_encoder:
raise ValueError("Alignment encoder is not supported for frame stacking")
# disallow alignment loss
if self.alignment_loss_scale > 0.0:
raise ValueError("Alignment loss is not supported for frame stacking")
# disallow training prior
if self.cfg.prior_scaling_factor is not None and self.cfg.prior_scaling_factor > 0:
raise ValueError("Training-time attention prior is not supported for frame stacking")
# disallow text conditioning
if self.use_text_conditioning_encoder:
raise ValueError("Text conditioning is not supported for frame stacking")
@property
def has_baked_context_embedding(self) -> bool:
"""Check if the model has a baked context embedding.
Returns:
True if baked_context_embedding buffer is set, not None, and has elements.
"""
return (
self.model_type == 'decoder_ce'
and hasattr(self, 'baked_context_embedding')
and self.baked_context_embedding is not None
and self.baked_context_embedding.numel() > 0
)
def update_ckpt(self, state_dict):
"""
Backward compatibility for checkpoints saved with old model names.
"""
new_state_dict = {}
for key in state_dict.keys():
if 't5_encoder' in key:
new_key = key.replace('t5_encoder', 'encoder')
new_state_dict[new_key] = state_dict[key]
elif 't5_decoder' in key:
new_key = key.replace('t5_decoder', 'decoder')
new_state_dict[new_key] = state_dict[key]
else:
new_state_dict[key] = state_dict[key]
return new_state_dict
def load_state_dict(self, state_dict, strict=True):
"""
Modify load_state_dict so that we don't restore weights to _speaker_verification_model and _codec_model when
strict is True.
When strict is False, we can call pytorch's load_state_dict.
When strict is True, we loop through all parameters and rename them to enable loading.
_speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts
model_type that is no longer supported and can likely be removed in a future version.
Also handles loading baked context embeddings. If the checkpoint contains baked_context_embedding,
context_encoder weights are not expected to be present.
"""
state_dict = self.update_ckpt(state_dict)
# Check if checkpoint has baked context embedding
has_baked_embedding_in_ckpt = (
'baked_context_embedding' in state_dict and state_dict['baked_context_embedding'] is not None
)
# Load baked embedding buffers if present
if has_baked_embedding_in_ckpt:
self.baked_context_embedding = state_dict['baked_context_embedding']
self.baked_context_embedding_len = state_dict['baked_context_embedding_len']
logging.info(
f"Loaded baked context embedding with shape {self.baked_context_embedding.shape}, "
f"length {self.baked_context_embedding_len.item()}"
)
if not strict:
super().load_state_dict(state_dict, strict=False)
# Build list of modules to skip
modules_to_skip = [
'_speaker_verification_model',
'_codec_model',
'_reference_model',
'eval_asr_model',
'eval_speaker_verification_model',
'whisper_model',
'squim_objective_model',
]
# Skip context_encoder if checkpoint has baked embedding (weights won't be in checkpoint)
if has_baked_embedding_in_ckpt:
modules_to_skip.append('context_encoder')
for name, child in self.named_children():
if name in modules_to_skip:
continue
if any(param.numel() > 0 for param in child.parameters()):
# If the module has parameters, we want to change the default mapping so that the state_dict gets
# loaded.
# Ex: state_dict[encoder.position_embeddings.weight] -> new_state_dict[position_embeddings.weight]
new_state_dict = {}
for key in state_dict.keys():
name_with_dot = f"{name}."
if key.startswith(name_with_dot):
new_state_dict[key[len(name_with_dot) :]] = state_dict[key]
child.load_state_dict(new_state_dict)
def audio_to_codes(self, audio, audio_len, audio_type='target'):
# audio: (B, T)
# audio_len: (B,)
if audio_type == 'target':
audio_eos_id = self.audio_eos_id
audio_bos_id = self.audio_bos_id
elif audio_type == 'context':
audio_eos_id = self.context_audio_eos_id
audio_bos_id = self.context_audio_bos_id
else:
raise ValueError(f"Received audio_type of {audio_type}. Must be `target` or `context`")
self._codec_model.eval()
with torch.no_grad(), torch.autocast(device_type=audio.device.type, dtype=torch.float32):
codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len)
if self._codec_converter is not None:
codes = self._codec_converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len)
# Add a timestep at the beginning and end of the codes tensor
bos_tensor = torch.full(
(codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device
)
# pad at the end to make room for the EOS token; the EOS token's actual position
# varies per batch element depending on each element's length.
pad_tensor = torch.full(
(codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device
) # 0 is the padding token in the audio codebook
codes = torch.cat([bos_tensor, codes, pad_tensor], dim=-1)
# codes: (B, C, T')
# codes_len: (B,)
for idx in range(codes.size(0)):
codes[idx, :, codes_len[idx] + 1] = audio_eos_id
codes_len = codes_len + 2 # +1 for bos and +1 for eos
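# At this point each batch element is laid out along time as [BOS, c_1, ..., c_L, EOS, <pad>, ...],
# and the returned length is the original codec length plus 2.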
return codes.long(), codes_len.long()
def codes_to_audio(self, codes, codes_len):
# codes: (B, C, T')
# codes_len: (B,)
self._codec_model.eval()
with torch.no_grad(), torch.autocast(device_type=codes.device.type, dtype=torch.float32):
# Make a copy to avoid modifying the original tensor if it's used elsewhere
codes_copy = codes.clone()
# Replace eos and bos tokens with padding in the copied tensor
codes_copy[codes == self.audio_bos_id] = 0 # zero is the padding token
codes_copy[codes == self.audio_eos_id] = 0
# Pass the modified integer token IDs
if self._codec_converter is not None:
codes_copy = self._codec_converter.convert_new_to_original(
audio_tokens=codes_copy, audio_lens=codes_len
)
audio, audio_len = self._codec_model.decode(tokens=codes_copy, tokens_len=codes_len)
# audio: (B, T)
# audio_len: (B,)
return audio, audio_len
def embed_audio_tokens(self, audio_tokens):
B, C, T = audio_tokens.shape
audio_embedding = None
for i in range(self.frame_stacking_factor):
for c in range(C):
tokens = audio_tokens[:, c, i :: self.frame_stacking_factor]
embedding = self.audio_embeddings[c + i * C](tokens)
if audio_embedding is None:
audio_embedding = embedding
else:
audio_embedding += embedding
audio_embedding = audio_embedding / (C * self.frame_stacking_factor)
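# audio_embedding: (B, T // frame_stacking_factor, E), the average of the embeddings of all codebooks
# and stacked frames at each (stacked) timestep.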
return audio_embedding
def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_offset_by_one=False):
"""
Predicts the logits for all codebooks using the local transformer. Used in both autoregressive (AR) and MaskGit (MG) modes.
This function is used in training and validation, not inference/sampling.
The sequence layout is slightly different between AR and MG modes, as shown in the diagram below,
(using an 8-codebook setup as an example):
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
| AR target | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | none |
| codebook | | | | | | | | | |
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
| MG target | none | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
| codebook | | | | | | | | | |
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
| input | Magpie | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
| codebook | latent | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK |
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
| seq. index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
+------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
dec_out: (B, T', E)
audio_codes_target: (B, C, T')
targets_offset_by_one: bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive)
if True, the target for index 1 is codebook 0, for index 2 is codebook 1, etc. (MaskGit)
"""
C = self.num_audio_codebooks
dec_out_all = dec_out.reshape(-1, dec_out.size(-1)) # (B*T', E)
local_transformer_input = [dec_out_all]
# Build the teacher-forced input to the LT.
for fs_index in range(self.frame_stacking_factor):
for codebook_num in range(C):
# Collect ground truth codes for the current codebook and frame stack index combination.
codes = audio_codes_target[:, codebook_num, fs_index :: self.frame_stacking_factor] # (B, T')
# Individual timesteps are handled independently by the LT, so we fold time into the batch dimension.
codes = codes.reshape(-1) # (B*T',)
# Embed the codes
codebook_embedding = self.audio_embeddings[codebook_num + fs_index * C](codes) # (B*T', E)
local_transformer_input.append(codebook_embedding)
# Stack the input codes along dimension 1 (codebooks). This is the dimension along which the LT predicts iteratively.
local_transformer_input = torch.stack(local_transformer_input, dim=1) # (B*T', C+1, E)
local_transformer_input = self.local_transformer_in_projection(local_transformer_input) # (B*T', C+1, D_local)
_mask = torch.ones(
local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device
)
local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B*T', C+1, E)
if not targets_offset_by_one:
# for autoregressive local transformer the target for index 0 is codebook 0, for index 1 is codebook 1, etc.
local_transformer_output = local_transformer_output[:, :-1, :] # (B*T', C, E)
else:
# for MaskGit the target for index **1** is codebook 0, for index 2 is codebook 1, etc.
local_transformer_output = local_transformer_output[:, 1:, :] # (B*T', C, E)
all_code_logits = []
for fs_index in range(self.frame_stacking_factor):
for codebook_num in range(audio_codes_target.size(1)):
# Using a separate projection layer for each codebook (to distinguish between them)
# Profiling showed this loop is cheap compared to the local transformer forward pass.
codebook_logits = self.local_transformer_out_projections[codebook_num + fs_index * C](
local_transformer_output[:, codebook_num + fs_index * C, :]
) # (B*T', num_all_tokens_per_codebook)
all_code_logits.append(codebook_logits)
all_code_logits = torch.cat(
all_code_logits, dim=1
) # (B*T'/frame_stacking_factor, num_codebooks * num_all_tokens_per_codebook * frame_stacking_factor)
all_code_logits = all_code_logits.view(
audio_codes_target.size(0), audio_codes_target.size(2) // self.frame_stacking_factor, -1
) # (B, T'/frame_stacking_factor, C * num_all_tokens_per_codebook * frame_stacking_factor)
return all_code_logits
def maskgit_create_random_mask(self, codes):
"""
Creates a mask where True indicates the positions that should be replaced with a MASK_TOKEN.
"""
# Codes: (B, C, T)
B, C, T = codes.shape
# get a random vector uniformly sampled from [0, 1) ## TODO: does it need to be inclusive on the right?
rand_values = torch.rand(B, T, device=codes.device)
# apply the cosine schedule
frac_masked = cosine_schedule(rand_values)
# how many positions to mask
n_masked = torch.ceil(frac_masked * C).long() # B,T
# The code further below is the vectorized version of this:
# for b in range(B):
# for t in range(T):
# if n_masked[b,t] > 0:
# # get a random permutation of the codebook indices
# perm = torch.randperm(C)
# # mask the top n_masked positions
# mask[b, perm[:n_masked[b,t]], t] = True
#
# Create random permutations
random_permutations = torch.argsort(torch.rand(B, C, T, device=codes.device), dim=1) # (B, C, T)
# Create a mask tensor where each position indicates if it should be masked
mask_indices = torch.arange(C, device=codes.device).view(1, C, 1)
mask = mask_indices < n_masked.view(B, 1, T) # (B, C, T)
# Apply the random permutations to the mask
mask = torch.gather(mask, 1, random_permutations)
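# Worked example (illustrative): with C=8 and n_masked[b, t] == 3, exactly three randomly chosen codebook
# positions in column (b, :, t) end up True and are later replaced with MASK_TOKEN by the caller.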
return mask # (B, C, T)
def maskgit_apply_random_mask(self, codes):
# Randomly replaces some codes with the MASK_TOKEN with a proportion following the cosine schedule.
# Codes: (B, C, T)
mask = self.maskgit_create_random_mask(codes)
# replace some tokens with MASK_TOKEN
codes_with_mask = torch.where(mask, self.mask_token_id, codes)
return codes_with_mask, mask
def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=None, frame_stacking_factor=1):
"""
Computes the audio codebook loss. Used by
(1) The main Magpie-TTS transformer
(2) The local transformer, for both autoregressive and MaskGit methods
logits: (B, T', num_codebooks * num_tokens_per_codebook)
audio_codes: (B, C, T')
audio_codes_lens: (B,)
mask_tokens_mask: (B, C, T') True for tokens that were replaced with the MASK_TOKEN and should
therefore be the only ones included in the loss computation (for MaskGit).
frame_stacking_factor: int, the stacking factor used in the model
"""
loss_mask = get_mask_from_lengths(audio_codes_lens, pad_to_factor=frame_stacking_factor)
if mask_tokens_mask is not None:
# For MaskGit we only compute loss for the masked tokens.
# *Both* conditions must be true:
# 1. the token is masked
# 2. the token is not padding
loss_mask = loss_mask.unsqueeze(1) * mask_tokens_mask
if not loss_mask.any():
# Without this we were very rarely getting NaNs in the loss
logging.warning("No tokens valid were found in compute_loss()!")
return torch.tensor(0.0, device=loss_mask.device), loss_mask
else:
# repeat loss mask for each codebook to simplify code below
loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1)
total_codebook_loss = None
for fs_index in range(frame_stacking_factor):
for codebook in range(audio_codes.size(1)):
si = (codebook + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook
ei = si + self.num_all_tokens_per_codebook
codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook)
codebook_targets = audio_codes[:, codebook, fs_index::frame_stacking_factor] # (B, T')
codebook_loss = self.cross_entropy_loss(
codebook_logits.permute(0, 2, 1), codebook_targets # (B, num_tokens_per_codebook, T')
) # (B, T')
codebook_loss_mask = loss_mask[:, codebook, fs_index::frame_stacking_factor]
codebook_loss = codebook_loss * codebook_loss_mask
if codebook_loss_mask.sum() == 0:
logging.warning(f"Loss mask for codebook {codebook} is all zeros, global_step: {self.global_step}")
continue
codebook_loss = codebook_loss.sum() / codebook_loss_mask.sum()
if total_codebook_loss is None:
total_codebook_loss = codebook_loss
else:
total_codebook_loss = total_codebook_loss + codebook_loss
total_codebook_loss = total_codebook_loss / (audio_codes.size(1) * frame_stacking_factor)
return total_codebook_loss, loss_mask
def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prior, multi_encoder_mapping):
decoder_out = self.decoder(
dec_input_embedded,
dec_input_mask,
cond=cond,
cond_mask=cond_mask,
attn_prior=attn_prior,
multi_encoder_mapping=multi_encoder_mapping,
)
attn_probabilities = decoder_out['attn_probabilities']
all_code_logits = self.final_proj(decoder_out['output']) # (B, T', num_codebooks * num_tokens_per_codebook)
return all_code_logits, attn_probabilities, decoder_out['output']
def logits_to_audio_codes(self, all_code_logits, audio_codes_lens):
# all_code_logits: (B, T', num_codebooks * num_tokens_per_codebook)
# audio_codes_lens: (B,)
all_preds = [[] for _ in range(self.frame_stacking_factor)]
for fs_index in range(self.frame_stacking_factor):
for idx in range(self.num_audio_codebooks):
si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook
ei = si + self.num_all_tokens_per_codebook
codebook_logits = all_code_logits[:, :, si:ei]
codebook_probs = torch.softmax(codebook_logits, dim=-1) # (B, T', num_tokens_per_codebook)
# argmax to get the tokens
codebook_preds = torch.argmax(codebook_probs, dim=-1) # (B, T')
all_preds[fs_index].append(codebook_preds)
all_preds = [
torch.stack(p, dim=1) for p in all_preds
] # list of `frame_stacking_factor` elements of shape (B, C, T) each
all_preds = torch.stack(all_preds, dim=-1) # B, C, T, frame_stacking_factor
# undo the frame stacking
all_preds = all_preds.reshape(all_preds.size(0), all_preds.size(1), -1) # B, C, T*frame_stacking_factor
pred_max_len = all_preds.size(2)
real_max_len = audio_codes_lens.max()
assert (pred_max_len - real_max_len) < self.frame_stacking_factor
# trim padding introduced for frame stacking
all_preds = all_preds[:, :, :real_max_len]
audio_mask = get_mask_from_lengths(audio_codes_lens)
all_preds = all_preds * audio_mask.unsqueeze(1)
return all_preds
def visualize_codes(self, codes, mask_id=2020, frame_stacking_rate=2):
"""
Visualize codes for analysis purposes
codes: (B, C)
"""
def code_to_str(code):
if code == mask_id:
return "M "
else:
return f"{code:04d} "
B, C = codes.shape
if B > 1:
logging.debug("Warning: visualizing only first batch element")
codes = codes.clone().detach().cpu().numpy()[0]
codes = [code_to_str(c) for c in codes]
output_str = ""
for i, c in enumerate(codes):
if (i) % (C / frame_stacking_rate) == 0:
output_str += "|timestep| "
output_str += c
logging.debug(output_str)
def clear_forbidden_logits(self, logits: torch.Tensor, forbid_audio_eos: bool = False) -> torch.Tensor:
"""
Sets logits of forbidden tokens to `-inf` so they will never be sampled.
Specifically, we forbid sampling of all special tokens except AUDIO_EOS
which is allowed by default.
Args:
logits: (B, C, num_audio_tokens_per_codebook)
forbid_audio_eos (bool, optional): If True, also forbid AUDIO_EOS tokens
from being sampled. Default: False.
"""
logits[
:,
:,
SpecialAudioToken.get_forbidden_tokens(self.codebook_size, forbid_audio_eos=forbid_audio_eos),
] = float('-inf')
return logits
def local_transformer_sample_maskgit(
self,
dec_output: torch.Tensor,
temperature: float = 0.7,
topk: int = 80,
unfinished_items: Dict[int, bool] = {},
finished_items: Dict[int, bool] = {},
use_cfg: bool = False,
cfg_scale: float = 1.0,
n_steps: int = 3,
noise_scale: float = 0.0,
fixed_schedule: Optional[List[int]] = None,
dynamic_cfg_scale: bool = False,
sampling_type: Optional[str] = None,
forbid_audio_eos: bool = False,
) -> torch.Tensor:
"""
Sample audio codes for the current timestep using MaskGit-like iterative
prediction with the local transformer. If frame-stacking is enabled, the
codes for all frames in the stack are sampled, treated as one long sequence.
The MaskGit process starts with all positions masked and iteratively unmasks the
most confident positions over multiple steps. By "masked" we mean that a
dedicated MASK token is used (as opposed to attention masking). The LT in this
case is a non-causal transformer decoder. At each step the model predicts all
positions at once. Of those predictions, a subset of the most confident
previously-masked positions is kept and unmasked in the next step. The number of
positions that are unmasked at each step is determined by the unmasking
schedule. We support a cosine schedule and a fixed schedule provided by the
user.
Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG).
Special handling:
* forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
* forces / forbids EOS for finished / unfinished items respectively
* optionally, globally forbids audio EOS for all items in the batch.
This is useful early in the generation process.
* supports different unmasking methods, see `sampling_type` argument for details.
Args:
dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size
and E is primary decoder's embedding dimension.
temperature (float, optional): Sampling temperature
topk (int, optional): Number of top-probability tokens to consider in sampling.
unfinished_items (dict, optional): Dictionary containing indices of batch
items that we are confident have not completed generation. For these items, audio EOS
sampling is forbidden.
finished_items (dict, optional): Dictionary containing indices of batch
items that we are confident are completed. For these items, audio EOS sampling
is forced.
use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size
to be doubled with conditional and unconditional outputs from the primary decoder.
cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
n_steps (int, optional): Number of iterative refinement steps for MaskGit sampling.
noise_scale (float, optional): Scale factor for noise to add to confidence scores
during sampling (experimental).
fixed_schedule (list, optional): Fixed schedule for number of tokens to unmask at each step.
If None, uses cosine schedule.
dynamic_cfg_scale (bool, optional): Whether to dynamically adjust CFG scale during
sampling (experimental).
sampling_type (str, optional): Type of sampling strategy. Options are:
["default", "causal", "purity_causal", "purity_default"].
* Purity refers to "purity sampling" from https://arxiv.org/abs/2304.01515. If "purity"
is not specified, confidence sampling is used as in the original MaskGit paper.
* "default"/"causal": Controls the order of unmasking across frames when frame-stacking is enabled.
If "causal" is specified, frames are unmasked in causal order. "default"
doesn't impose any constraints on the unmasking order.
forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
batch.
Returns:
torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor)
"""
# dec_output: (B, E)
device = dec_output.device
# disable KV cache since our transformer is not causal
self.local_transformer.reset_cache(use_cache=False)
dec_output = dec_output.unsqueeze(1) # (B, 1, E)
local_transformer_input_init = self.local_transformer_in_projection(
dec_output
) # (B, 1, D) where D is the dimension of the local transformer
codebook_seq_len = self.num_audio_codebooks * self.frame_stacking_factor
B = dec_output.size(0)
min_confidence = 0
# this needs to be large enough that unmasked items will always remain unmasked (even after noise addition)
# Setting it smaller could allow "regret", i.e. re-masking a codebook that was previously unmasked; we might want to try that
max_confidence = 5
confidences = min_confidence * torch.ones(B, codebook_seq_len, device=device)
# initialize to all masked
codes = self.mask_token_id * torch.ones((B, codebook_seq_len), device=device, dtype=torch.long)
sampled_codes = codes.clone()
topk_indices = None
if fixed_schedule is not None:
n_steps = len(fixed_schedule)
for step in range(n_steps):
# how far along we are in the unmasking process
progress = step / n_steps
# get mask fraction
frac_masked = cosine_schedule(torch.tensor(progress))
if sampling_type == "causal" or sampling_type == "purity_causal":
frac_masked = torch.ones_like(frac_masked) * (1.0 - progress)
# how many codebooks to mask
if fixed_schedule is None:
n_masked = torch.ceil(codebook_seq_len * frac_masked).long()
else:
n_masked = codebook_seq_len - fixed_schedule[step]
n_unmasked = codebook_seq_len - n_masked
if (
sampling_type == "causal" or sampling_type == "purity_causal"
): # and n_unmasked <= self.num_audio_codebooks:
# restrict unmasking to the first `n_frames_to_allow` frames; later frames are forced to remain masked
n_frames_to_allow = int(np.floor(progress * self.frame_stacking_factor + 1))
confidences[:, n_frames_to_allow * self.num_audio_codebooks :] = (
min_confidence - 1
) # only tested for frame_stacking_factor=2
# pick top-confidence codebooks up to n_unmasked
_, topk_indices = torch.topk(confidences, k=n_unmasked, dim=1)
if use_cfg:
actual_batch_size = topk_indices.size(0) // 2
assert (
topk_indices[actual_batch_size:] == topk_indices[:actual_batch_size]
).all(), "Topk indices are not the same for conditional and unconditional codes"
# replace masks of the top-k confident codebooks with the codes that were sampled for them
unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices)
codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes)
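# Example of the unmasking bookkeeping above (positions chosen for illustration): if topk_indices selects
# positions {0, 3}, the codes sampled for those positions in the previous step are copied into `codes`;
# all other positions stay at mask_token_id and are re-predicted below.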
# build transformer input
local_transformer_input = local_transformer_input_init
for codebook_num in range(codebook_seq_len):
next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze(
1
) # (B, 1, 768)
next_local_transformer_input = self.local_transformer_in_projection(
next_local_transformer_input
) # (B, 1, d_local)
local_transformer_input = torch.cat(
[local_transformer_input, next_local_transformer_input], dim=1
) # (B, codebook_num+1, d_local)
# run transformer
_mask = torch.ones(B, codebook_seq_len + 1, device=device)
local_transformer_output = self.local_transformer(local_transformer_input, _mask)[
'output'
] # (B, C+1, d_local)
# get logits
logits = []
for codebook_num in range(codebook_seq_len):
# The `codebook_num + 1` offset drops the first position, which corresponds to the magpie latent
codebook_logits = self.local_transformer_out_projections[codebook_num](
local_transformer_output[:, codebook_num + 1, :]
) # (B, num_audio_tokens_per_codebook)
logits.append(codebook_logits)
logits = torch.stack(logits, dim=1) # (B, C*frame_stacking_factor, num_audio_tokens_per_codebook)
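# Position 0 of the local transformer output corresponds to the primary-decoder latent and produces no
# logits; output position k+1 is projected by head k to predict codebook position k.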
# apply CFG
if use_cfg:
actual_batch_size = logits.size(0) // 2
conditional_logits = logits[:actual_batch_size]
unconditional_logits = logits[actual_batch_size:]
if not dynamic_cfg_scale:
current_cfg_scale = cfg_scale
else:
# CFG scale schedule; the active variant ramps the scale linearly over the sampling steps
progress = step / (n_steps - 1)
# interp = -abs(progress-0.5)+0.5 # increase from 0 to 0.5 until the midpoint, then back down to 0
# interp = 1.0 - progress # decrease from 1 to 0
interp = progress # gradually increase from 0 to 1
current_cfg_scale = (cfg_scale - 1) * interp + 1.0 # 1.0 at the first step --> cfg_scale at the last step
cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits
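# The line above is the usual CFG combination, uncond + scale * (cond - uncond); below, only the
# conditional half of the batch is overwritten with the guided logits, the unconditional half keeps its raw logits.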
logits[:actual_batch_size] = cfg_logits
# Disallow generation of special tokens
logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos)
# handle unfinished and finished items
for item_idx in unfinished_items:
logits[item_idx, :, self.audio_eos_id] = float('-inf')
for item_idx in finished_items:
logits[item_idx, :, :] = float('-inf')
logits[item_idx, :, self.audio_eos_id] = 0.0
# sample with top-k
logits_topk = torch.topk(logits, topk, dim=-1)[0] # (B, C, topk)
indices_to_remove = logits < logits_topk[:, :, -1].unsqueeze(-1) # (B, C, num_audio_tokens_per_codebook)
logits_rescored = logits.clone()
logits_rescored[indices_to_remove] = float('-inf')
probs = torch.softmax(logits_rescored / temperature, dim=-1) # (B, C, num_audio_tokens_per_codebook)
sampled_codes = torch.multinomial(probs.view(B * codebook_seq_len, -1), 1).view(B, codebook_seq_len)
if use_cfg:
sampled_codes[actual_batch_size:] = sampled_codes[:actual_batch_size]
probs[actual_batch_size:] = probs[:actual_batch_size]
if sampling_type != "purity_causal" and sampling_type != "purity_default":
confidences = torch.gather(probs, dim=2, index=sampled_codes.unsqueeze(-1)).squeeze(-1)
else:
# use the max probability across all tokens for each codebook as the confidence for each codebook; known as "purity sampling"
confidences = probs.max(dim=2)[0]
# replace entries in sampled_codes with previously unmasked codebooks
sampled_codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes)
# add noise to confidences (as in token-critic paper, https://arxiv.org/abs/2209.04439)
if noise_scale > 0.0:
# get noise from uniform distribution in the interval [-0.5, 0.5), scale it by `noise_scale`,
# and anneal it to 0 as we approach the end of the unmasking process
noise = (
(torch.rand_like(confidences) - 0.5) * noise_scale * (1 - (step + 2) / n_steps)
) # the +2 makes sure that by the last iteration the noise is exactly 0
confidences += noise
# the conditional and unconditional halves get different noise and must be fixed to be the same again
if use_cfg:
confidences[actual_batch_size:] = confidences[:actual_batch_size]
confidence_eps = 0.1
assert (
confidences.max() + confidence_eps < max_confidence
), f"Predicted confidence is approaching max_confidence: {confidences.max()}"
# for unmasked codebooks, set confidence to max so that they will remain unmasked
confidences.scatter_(
index=topk_indices, dim=1, src=max_confidence * torch.ones_like(topk_indices, dtype=torch.float)
)
codes = sampled_codes
assert not (
codes == self.mask_token_id
).any(), "Codes contain mask tokens after completion of MaskGit sampling"
# break stacked groups of frames into individual frames
codes = codes.reshape(B, self.frame_stacking_factor, self.num_audio_codebooks).permute(
0, 2, 1
) # B, C, frame_stacking_factor
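# Layout example (illustrative, assuming 8 codebooks and frame_stacking_factor=2): the 16 flat positions are
# frame-major [f0c0..f0c7, f1c0..f1c7], so reshape(B, 2, 8).permute(0, 2, 1) recovers (B, 8, 2).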
if use_cfg:
# drop unconditional codes
codes = codes[:actual_batch_size]
return codes
def local_transformer_sample_autoregressive(
self,
dec_output: torch.Tensor,
temperature: float = 0.7,
topk: int = 80,
unfinished_items: Dict[int, bool] = {},
finished_items: Dict[int, bool] = {},
use_cfg: bool = False,
cfg_scale: float = 1.0,
use_kv_cache: bool = True,
forbid_audio_eos: bool = False,
) -> torch.Tensor:
"""
Sample audio codes autoregressively across codebooks using the local
transformer. Uses multinomial sampling with temperature, top-k, and
classifier-free guidance (CFG).
The sequence is initialized with the primary decoder's hidden output as its only input and is
extended one codebook at a time, with each sampled code appended to the input sequence for the
next step. At the last step the sequence is `num_codebooks` codes long. If frame stacking is
enabled, the codes for all frames in the stack are sampled as one long sequence, so the final
sequence is `num_codebooks * frame_stacking_factor` codes long.
Special handling:
* forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
* forces / forbids EOS for finished / unfinished items respectively
* optionally, globally forbids audio EOS (useful early in the generation process)
Args:
dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size
and E is primary decoder's embedding dimension.
temperature (float, optional): Sampling temperature.
topk (int, optional): Number of top-probability tokens to consider in sampling.
unfinished_items (dict, optional): Dictionary containing indices of batch
items that we are confident have not completed generation. For these items, audio EOS
sampling is forbidden.
finished_items (dict, optional): Dictionary containing indices of batch
items that we are confident are completed. For these items, audio EOS sampling
is forced.
use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size
to be doubled with conditional and unconditional outputs from the primary decoder.
cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
use_kv_cache (bool, optional): Whether to use key-value caching in the transformer.
forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
batch.
Returns:
torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor)
where B is batch size (or actual_batch_size if use_cfg=True).
"""
self.local_transformer.reset_cache(use_cache=use_kv_cache)
dec_output = dec_output.unsqueeze(1) # (B, 1, E)
local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128)
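# The loop below grows the local-transformer input one codebook at a time: [latent] -> [latent, emb(c0)] -> ...;
# at each step the last position's hidden state is projected by that codebook's output head.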
all_preds = []
for codebook_num in range(self.num_audio_codebooks * self.frame_stacking_factor):
_mask = torch.ones(
local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device
)
local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, T, 128)
codebook_logits = self.local_transformer_out_projections[codebook_num](
local_transformer_output[:, -1, :]
) # (B, num_all_tokens_per_codebook)
if use_cfg:
actual_batch_size = codebook_logits.size(0) // 2
conditional_logits = codebook_logits[:actual_batch_size]
unconditional_logits = codebook_logits[actual_batch_size:]
cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits
codebook_logits[:actual_batch_size] = cfg_logits
for item_idx in unfinished_items:
codebook_logits[item_idx, self.audio_eos_id] = float('-inf')
for item_idx in finished_items:
codebook_logits[item_idx, :] = float('-inf')
codebook_logits[item_idx, self.audio_eos_id] = 0.0
# Disallow generation of special tokens
codebook_logits = self.clear_forbidden_logits(
codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
).squeeze(1)
codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk)
indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
-1
) # (B, num_tokens_per_codebook)
codebook_logits_rescored = codebook_logits.clone()
codebook_logits_rescored[indices_to_remove] = float('-inf')
codebook_probs = torch.softmax(
codebook_logits_rescored / temperature, dim=-1
) # (B, num_tokens_per_codebook)
codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1)
if use_cfg:
codebook_preds[actual_batch_size:] = codebook_preds[:actual_batch_size]
all_preds.append(codebook_preds)
next_local_transformer_input = self.audio_embeddings[codebook_num](codebook_preds.squeeze(-1)).unsqueeze(
1
) # (B, 1, 128)
next_local_transformer_input = self.local_transformer_in_projection(
next_local_transformer_input
) # (B, 1, 128)
local_transformer_input = torch.cat(
[local_transformer_input, next_local_transformer_input], dim=1
) # (B, T+1, 128)
all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks * frame_stacking_factor)
all_preds = all_preds.reshape(-1, self.frame_stacking_factor, self.num_audio_codebooks).permute(
0, 2, 1
) # (B, num_codebooks, frame_stacking_factor)
if use_cfg:
all_preds = all_preds[:actual_batch_size]
return all_preds
def sample_codes_from_logits(
self,
all_code_logits_t: torch.Tensor,
temperature: float = 0.7,
topk: int = 80,
unfinished_items: Dict[int, bool] = {},
finished_items: Dict[int, bool] = {},
forbid_audio_eos: bool = False,
) -> torch.Tensor:
"""
Sample codes for all codebooks at a given timestep. Uses multinomial sampling
with temperature and top-k. If frame stacking is on (i.e. `frame_stacking_factor
> 1`), this function will sample across the entire frame stack.
Special handling:
* forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
* forces / forbids EOS for finished / unfinished items respectively
* optionally, globally forbids audio EOS (useful early in the generation process)
Args:
all_code_logits_t (torch.Tensor): Logits at a given timestep with shape
(B, num_tokens_per_codebook * num_codebooks * frame_stacking_factor)
temperature (float, optional): Sampling temperature
topk (int, optional): Number of top-probability tokens to consider in sampling.
unfinished_items (dict, optional): Dictionary containing indices of batch
items that we are confident have not completed generation. For these items, audio EOS
sampling is forbidden.
finished_items (dict, optional): Dictionary containing indices of batch
items that we are confident are completed. For these items, audio EOS sampling
is forced.
forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
batch.
Returns:
torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor).
"""
all_preds = [[] for _ in range(self.frame_stacking_factor)]
for fs_index in range(self.frame_stacking_factor):
for idx in range(self.num_audio_codebooks):
si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook
ei = si + self.num_all_tokens_per_codebook
codebook_logits = all_code_logits_t[:, si:ei] # (B, num_tokens_per_codebook)
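# Flat layout example (illustrative, assuming 8 codebooks and 1025 tokens per codebook): frame 0 / codebook 2
# occupies all_code_logits_t[:, 2*1025 : 3*1025]; frame 1 / codebook 0 starts at 8*1025 = 8200.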
for item_idx in unfinished_items:
codebook_logits[item_idx, self.audio_eos_id] = float('-inf')
for item_idx in finished_items:
codebook_logits[item_idx, :] = float('-inf')
codebook_logits[item_idx, self.audio_eos_id] = 0.0
# Disallow generation of special tokens
codebook_logits = self.clear_forbidden_logits(
codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
).squeeze(1)
codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk)
indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
-1
) # (B, num_tokens_per_codebook)
codebook_logits_rescored = codebook_logits.clone()
codebook_logits_rescored[indices_to_remove] = float('-inf')
codebook_probs = torch.softmax(
codebook_logits_rescored / temperature, dim=-1
) # (B, num_tokens_per_codebook)
codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1)
all_preds[fs_index].append(codebook_preds)
all_preds = [
torch.cat(ds_preds, dim=1).long() for ds_preds in all_preds
] # list of `frame_stacking_factor` elements, each of shape (B, num_codebooks)
all_preds = torch.stack(all_preds, dim=2) # (B, num_codebooks, frame_stacking_factor)
return all_preds
def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens, prefix="", dec_context_size=0):
# attention_prob_matrix List of (B, C, audio_timesteps, text_timesteps)
wandb_images_log = {}
with torch.no_grad():
attention_prob_matrix = torch.cat(attention_prob_matrix, dim=1) # (B, C, audio_timesteps, text_timesteps)
attention_prob_matrix_mean = attention_prob_matrix.mean(dim=1) # (B, audio_timesteps, text_timesteps)
for logger in self.loggers:
is_wandb = isinstance(logger, WandbLogger)
is_tb = isinstance(logger, TensorBoardLogger)
if not is_wandb and not is_tb:
raise ValueError(
f"Invalid logger type for image logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
)
wandb_images_log[f"Image/{prefix}/attention_matrix"] = list()
for idx in range(min(3, attention_prob_matrix_mean.size(0))):
item_attn_matrix = attention_prob_matrix_mean[idx][
dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx]
]
item_attn_matrix = item_attn_matrix.detach().cpu().numpy()
img_np = plot_alignment_to_numpy(item_attn_matrix.T)
if is_wandb:
wandb_images_log[f"Image/{prefix}/attention_matrix"].append(
wandb.Image(img_np, caption=f"Example_{idx}")
)
if is_tb:
logger.experiment.add_image(
f'{prefix}/attention_matrix/Example_{idx}',
img_np,
global_step=self.global_step,
dataformats="HWC",
)
return wandb_images_log
def log_val_audio_example(
self,
logits,
target_audio_codes,
audio_codes_lens_target,
context_audio_codes=None,
context_audio_codes_lens=None,
):
wandb_audio_log = {}
pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens_target)
pred_audio, pred_audio_lens = self.codes_to_audio(pred_audio_codes, audio_codes_lens_target)
target_audio, target_audio_lens = self.codes_to_audio(target_audio_codes, audio_codes_lens_target)
context_audio, context_audio_lens = None, None
if context_audio_codes is not None and context_audio_codes.shape[2] > 3:
# > 3 ensures it is a valid context audio tensor (and not the dummy tensor used for text context)
context_audio, context_audio_lens = self.codes_to_audio(context_audio_codes, context_audio_codes_lens)
for logger in self.loggers:
is_wandb = isinstance(logger, WandbLogger)
is_tb = isinstance(logger, TensorBoardLogger)
if not is_wandb and not is_tb:
raise ValueError(
f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
)
for idx in range(min(3, pred_audio.size(0))):
pred_audio_np = pred_audio[idx].float().detach().cpu().numpy()
target_audio_np = target_audio[idx].float().detach().cpu().numpy()
pred_audio_np = pred_audio_np[: pred_audio_lens[idx]]
target_audio_np = target_audio_np[: target_audio_lens[idx]]
context_audio_np = None
if context_audio is not None:
context_audio_np = context_audio[idx].float().detach().cpu().numpy()
context_audio_np = context_audio_np[: context_audio_lens[idx]]
if is_wandb:
wandb_audio_log[f"Audio/Example_{idx}"] = list()
if context_audio_np is not None:
wandb_audio_log[f"Audio/Example_{idx}"].append(
wandb.Audio(context_audio_np, sample_rate=self.sample_rate, caption="context")
)
wandb_audio_log[f"Audio/Example_{idx}"].append(
wandb.Audio(pred_audio_np, sample_rate=self.sample_rate, caption="prediction")
)
wandb_audio_log[f"Audio/Example_{idx}"].append(
wandb.Audio(target_audio_np, sample_rate=self.sample_rate, caption="target")
)
if is_tb:
if context_audio_np is not None:
logger.experiment.add_audio(
f'Example_{idx}/context',
context_audio_np,
global_step=self.global_step,
sample_rate=self.sample_rate,
)
logger.experiment.add_audio(
f'Example_{idx}/prediction',
pred_audio_np,
global_step=self.global_step,
sample_rate=self.sample_rate,
)
logger.experiment.add_audio(
f'Example_{idx}/target',
target_audio_np,
global_step=self.global_step,
sample_rate=self.sample_rate,
)
return wandb_audio_log
def scale_prior(self, prior, global_step):
if prior is None:
return None
if global_step < self.prior_scaledown_start_step:
return prior
elif global_step >= self.prior_end_step:
if random.random() < self.indefinite_prior_prob:
print("Using Prior")
return prior
else:
print("Not using Prior")
return None
else:
with torch.no_grad():
# Interpolate between all ones and the prior
residual = 1.0 - prior
new_prior = prior + (
residual
* (global_step - self.prior_scaledown_start_step)
/ (self.prior_end_step - self.prior_scaledown_start_step)
)
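# Worked example: halfway between prior_scaledown_start_step and prior_end_step the interpolation factor is 0.5,
# so new_prior = prior + 0.5 * (1 - prior) = 0.5 * prior + 0.5, i.e. halfway toward an uninformative all-ones prior.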
return new_prior
def embed_text(self, text, text_mask):
if self.use_bpe_char_tokenizer:
text_embedded = self.cas_encoder(text, subword_mask=text_mask)
else:
text_embedded = self.text_embedding(text)
return text_embedded
def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_context_size=0):
# attention scores: List of (B, C, audio_timesteps, text_timesteps)
attention_scores_combined = torch.cat(attention_scores, dim=1) # (B, C, audio_timesteps, text_timesteps)
attention_scores_mean = attention_scores_combined.mean(
dim=1, keepdim=True
) # (B, 1, audio_timesteps, text_timesteps)
attention_scores_mean = attention_scores_mean[
:, :, dec_context_size:, :
] # Remove the context audio embeddings from the attention scores
alignment_loss = self.alignment_loss(
attn_logprob=attention_scores_mean, in_lens=text_lens, out_lens=audio_lens
)
return alignment_loss
def pad_audio_codes(self, audio_codes: torch.Tensor, frame_stacking_factor: int = 1, pad_token: int = 0):
"""
Pads the time dimension of the audio codes to a multiple of the frame stacking factor.
Args:
audio_codes (torch.Tensor): B, C, T
frame_stacking_factor (int): The factor that frames will be stacked by.
pad_token (int): The token ID to pad with.
Returns:
B, C, T_padded
"""
T = audio_codes.size(2)
T_padded = int(np.ceil(T / frame_stacking_factor) * frame_stacking_factor)
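# e.g. T=13 with frame_stacking_factor=4 gives T_padded=16, so 3 frames of `pad_token` are appended below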
if T_padded > T:
padding = pad_token * torch.ones(
audio_codes.size(0),
audio_codes.size(1),
T_padded - T,
device=audio_codes.device,
dtype=audio_codes.dtype,
)
audio_codes = torch.cat([audio_codes, padding], dim=2)
return audio_codes
def embed_context_text(self, context_text_tokens):
if self.legacy_text_conditioning:
context_text_tokens = (
context_text_tokens - self.tokenizer.tokenizer_offsets[self.text_conditioning_tokenizer_name]
)
context_text_embedded = self.context_text_embedding(context_text_tokens) # (B, L, E)
else:
context_text_embedded = self.text_embedding(context_text_tokens) # (B, L, E)
return context_text_embedded
def prepare_context_tensors(self, batch):
dec_context_size = 0
additional_decoder_input = None
additional_decoder_mask = None
context_audio_codes = None
context_audio_codes_lens = None
_attn_prior = None
attn_prior = None
cond = None
cond_mask = None
multi_encoder_mapping = None
text = None
text_lens = None
# self.model_type must be one of [multi_encoder_context_tts, decoder_context_tts, decoder_ce]
text = batch['text']
text_lens = batch['text_lens']
text_mask = get_mask_from_lengths(text_lens) # (B, T)
text_embedded = self.embed_text(text, text_mask) # (B, T, E)
text_encoder_out = self.encoder(text_embedded, text_mask, cond=None, cond_mask=None)['output'] # (B, T, E)
_attn_prior = batch.get('align_prior_matrix', None)
_attn_prior = self.scale_prior(_attn_prior, self.global_step)
if self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts', 'decoder_ce']:
if 'context_audio_codes' in batch:
context_audio_codes = batch['context_audio_codes']
context_audio_codes_lens = batch['context_audio_codes_lens']
if self._codec_converter is not None:
context_audio_codes = self._codec_converter.convert_original_to_new(
audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens
).long()
else:
context_audio_codes, context_audio_codes_lens = self.audio_to_codes(
batch['context_audio'], batch['context_audio_lens'], audio_type='context'
)
context_audio_codes = self.pad_audio_codes(context_audio_codes, self.frame_stacking_factor, pad_token=0)
context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T/frame_stacking_factor, E)
if self.use_text_conditioning_encoder:
context_text_tokens = batch['context_text_tokens']
context_text_lens = batch['context_text_tokens_lens']
context_text_embedded = self.embed_context_text(context_text_tokens) # (B, L, E)
# Pad context_audio_embedded or context_text_embedded so that they have same number of timesteps
if context_audio_embedded.size(1) < context_text_embedded.size(1):
padding = torch.zeros(
context_audio_embedded.size(0),
context_text_embedded.size(1) - context_audio_embedded.size(1),
context_audio_embedded.size(2),
device=context_audio_embedded.device,
)
context_audio_embedded = torch.cat([context_audio_embedded, padding], dim=1)
elif context_audio_embedded.size(1) > context_text_embedded.size(1):
padding = torch.zeros(
context_text_embedded.size(0),
context_audio_embedded.size(1) - context_text_embedded.size(1),
context_text_embedded.size(2),
device=context_text_embedded.device,
)
context_text_embedded = torch.cat([context_text_embedded, padding], dim=1) # (B, T, E)
has_text_context = batch['has_text_context'].unsqueeze(-1).unsqueeze(-1).float() # (B, 1, 1)
context_input_embedded = (
has_text_context * context_text_embedded + (1 - has_text_context) * context_audio_embedded
)
context_input_lens = (
batch['has_text_context'].float() * context_text_lens
+ (1 - batch['has_text_context'].float()) * context_audio_codes_lens
) # (B,)
else:
context_input_embedded = context_audio_embedded
context_input_lens = context_audio_codes_lens
context_input_lens = torch.ceil(context_input_lens / self.frame_stacking_factor).to(
context_input_lens.dtype
)
context_mask = get_mask_from_lengths(context_input_lens)
if self.model_type == 'multi_encoder_context_tts':
context_embeddings = self.context_encoder(
context_input_embedded, context_mask, cond=None, cond_mask=None
)['output']
cond = [text_encoder_out, context_embeddings]
cond_mask = [text_mask, context_mask]
multi_encoder_mapping = self.multi_encoder_mapping
attn_prior = [_attn_prior, None]
elif self.model_type in ['decoder_context_tts', 'decoder_ce']:
context_embeddings = None # Address CodeQL
if self.model_type == 'decoder_context_tts':
context_embeddings = context_input_embedded
elif self.model_type == 'decoder_ce':
# Check for baked context embedding first
if self.has_baked_context_embedding:
# self.baked_context_embedding is a fixed context embedding that is baked into the model.
# This is used when we do not want users to generate speech with context audio or context text.
# This is done to disable zero-shot inference. Users can only generate speech in 1 voice chosen
# by the model development team.
batch_size = text.size(0)
# Expand baked embedding to batch size: (T, E) -> (B, T, E)
context_embeddings = self.baked_context_embedding.unsqueeze(0).expand(batch_size, -1, -1)
# Create context mask from baked length
context_input_lens = (
self.baked_context_embedding_len.unsqueeze(0).expand(batch_size).to(text.device)
)
context_mask = get_mask_from_lengths(context_input_lens)
else:
context_embeddings = self.context_encoder(
context_input_embedded, context_mask, cond=None, cond_mask=None
)['output']
dec_context_size = context_mask.size(1)
attn_prior = _attn_prior
if attn_prior is not None:
# B, audio_timesteps, text_timesteps
padding_zeros = torch.zeros(
attn_prior.size(0), dec_context_size, attn_prior.size(2), device=attn_prior.device
)
attn_prior = torch.cat([padding_zeros, attn_prior], dim=1)
cond = text_encoder_out
cond_mask = text_mask
multi_encoder_mapping = None
additional_decoder_input = context_embeddings
additional_decoder_mask = context_mask
else:
raise ValueError(f"Unsupported model type {self.model_type}")
if attn_prior is not None and self.ctc_prior_layer_ids is not None:
# Convert prior to a list of tensors, one for each layer
# Set None for layers not in ctc_prior_layer_ids
if self.model_type == 'multi_encoder_context_tts':
text_attn_prior = [
attn_prior[0] if layer_idx in self.ctc_prior_layer_ids else None
for layer_idx in range(self.decoder.n_layers)
]
attn_prior = [text_attn_prior, attn_prior[1]]
else:
attn_prior = [
attn_prior if layer_idx in self.ctc_prior_layer_ids else None
for layer_idx in range(self.decoder.n_layers)
]
return {
'beta_binomial_attn_prior': batch.get('align_prior_matrix', None),
'text_encoder_out': text_encoder_out,
'cond': cond,
'cond_mask': cond_mask,
'attn_prior': attn_prior,
'prior_used': _attn_prior is not None,
'multi_encoder_mapping': multi_encoder_mapping,
'additional_decoder_input': additional_decoder_input,
'additional_decoder_mask': additional_decoder_mask,
'dec_context_size': dec_context_size,
'text': text,
'text_embedded': text_embedded,
'text_mask': text_mask,
'text_lens': text_lens,
'context_audio_codes': context_audio_codes,
'context_audio_codes_lens': context_audio_codes_lens,
}
def replace_beta_binomial_prior_with_binarized(self, attn_prior, aligner_attn_hard):
# aligner_attn_hard B, audio_timesteps, text_timesteps
if self.model_type == 'multi_encoder_context_tts':
text_attn_prior = attn_prior[0]
else:
text_attn_prior = attn_prior
assert text_attn_prior is not None, "Prior is None"
if isinstance(text_attn_prior, list):
# Layer wise prior
prior_updated = False
for idx, prior in enumerate(text_attn_prior):
if prior is not None:
text_attn_prior[idx][:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard
prior_updated = True
assert prior_updated, "Did not find any prior to update"
else:
# Same prior for all layers
text_attn_prior[:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard
if self.model_type == 'multi_encoder_context_tts':
attn_prior[0] = text_attn_prior
else:
attn_prior = text_attn_prior
return attn_prior
def get_binarized_prior_matrix(self, aligner_attn_soft, audio_lens, text_lens):
# aligner_attn_soft B, 1, audio_timesteps, text_timesteps
if self.binarize_attn_method == 'nemo_binarize':
logging.debug("Binarizing attention using nemo_binarize")
binarize_repeat_audio_factor = self.binarize_repeat_audio_factor
aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave(
binarize_repeat_audio_factor, dim=2
) # B, 1, 2*audio_timesteps, text_timesteps
aligner_attn_hard = binarize_attention_parallel(
aligner_attn_soft_repeated, text_lens, audio_lens * binarize_repeat_audio_factor
).squeeze(
1
) # B, 2*audio_timesteps, text_timesteps
aligner_attn_hard = aligner_attn_hard[:, ::2, :] # B, audio_timesteps, text_timesteps
elif self.binarize_attn_method == 'argmax':
logging.debug("Binarizing attention using argmax")
aligner_attn_hard = torch.argmax(aligner_attn_soft.squeeze(1), dim=-1)
aligner_attn_hard = torch.nn.functional.one_hot(
aligner_attn_hard, num_classes=aligner_attn_soft.size(-1)
).float()
else:
raise ValueError(
f"self.binarize_attn_method '{self.binarize_attn_method}' must be one of 'nemo_binarize' or 'argmax'."
)
aligner_attn_hard_wider = aligner_attn_hard + self.binarized_prior_epsilon
for future_timestep in range(self.prior_future_context):
decay_factor = self.prior_future_decay ** (future_timestep + 1)
aligner_attn_hard_wider[:, :, future_timestep + 1 :] += (
decay_factor * aligner_attn_hard[:, :, : -(future_timestep + 1)]
)
for past_timestep in range(self.prior_past_context):
decay_factor = self.prior_past_decay ** (past_timestep + 1)
aligner_attn_hard_wider[:, :, : -past_timestep - 1] += (
decay_factor * aligner_attn_hard[:, :, past_timestep + 1 :]
)
aligner_attn_hard_wider = torch.clamp(aligner_attn_hard_wider, 0.0, 1.0)
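# Widening example (illustrative decay of 0.8 for both directions): a hard alignment at text position t also
# receives 0.8 at t+1 and 0.64 at t+2 (and symmetrically into the past) before clamping to [0, 1].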
return aligner_attn_hard_wider
def prepare_dummy_cond_for_cfg(self, cond, cond_mask, additional_decoder_input, additional_dec_mask):
dummy_additional_decoder_input = None
dummy_additional_dec_mask = None
if additional_decoder_input is not None:
dummy_additional_decoder_input = torch.zeros_like(additional_decoder_input)
# all ones mask means dont ignore any timesteps (so that it is consistent with usual decoder mask)
dummy_additional_dec_mask = torch.ones_like(additional_dec_mask)
if isinstance(cond, list):
# multi encoder conditioning
dummy_cond = [torch.zeros_like(cond_item) for cond_item in cond]
attn_prior = [None for _ in cond]
dummy_mask = []
for mask_item in cond_mask:
# ignore all timesteps except the first one
mask = torch.zeros_like(mask_item)
mask[:, 0] = 1 # keep only the first timestep unmasked
dummy_mask.append(mask)
elif isinstance(cond, torch.Tensor):
# single encoder conditioning
dummy_cond = torch.zeros_like(cond)
dummy_mask = torch.zeros_like(cond_mask)
dummy_mask[:, 0] = 1 # ignore all timesteps except the first one
attn_prior = None
else:
raise ValueError(f"Unsupported type for cond {type(cond)}")
return dummy_cond, dummy_mask, dummy_additional_decoder_input, dummy_additional_dec_mask, attn_prior
def process_batch(self, batch, mode="train"):
context_tensors = self.prepare_context_tensors(batch)
disable_alignment_loss = False
if 'audio_codes' not in batch:
audio_codes, audio_codes_lens = self.audio_to_codes(batch['audio'], batch['audio_lens'])
else:
audio_codes = batch['audio_codes']
audio_codes_lens = batch['audio_codes_lens']
if self._codec_converter:
audio_codes = self._codec_converter.convert_original_to_new(
audio_tokens=audio_codes, audio_lens=audio_codes_lens
).long()
if self.frame_stacking_factor > 1:
# repeat the BOS token to frame_stacking_factor times. This is necessary since at inference
# we need to start autoregressive generation from a full stack indicating BOS.
# TODO: @rfejgin: this assert might be slow due to GPU/CPU sync
assert (audio_codes[:, :, 0] == self.audio_bos_id).all(), "Audio codes do not start with BOS token"
audio_codes = torch.cat(
[
torch.full(
(audio_codes.size(0), audio_codes.size(1), self.frame_stacking_factor - 1),
self.audio_bos_id,
device=audio_codes.device,
dtype=audio_codes.dtype,
),
audio_codes,
],
dim=2,
)
audio_codes_lens += self.frame_stacking_factor - 1 # account for BOS repeat
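# BOS-repeat example: with frame_stacking_factor=2, [BOS, c1, c2, ...] becomes [BOS, BOS, c1, c2, ...],
# so the first stacked frame is all BOS and inference can start from a full BOS stack.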
audio_codes = self.pad_audio_codes(audio_codes, self.frame_stacking_factor, pad_token=0)
# Note: if a tensor lacks the `_unstacked` suffix, it can be assumed to be in the frame-stacked domain
# drop last (stacked) frame since it is not part of *input*
audio_codes_input_unstacked = audio_codes[:, :, : -self.frame_stacking_factor] # B, C, T'
# drop first (stacked) frame which contains BOS token(s) which are not part of *target*
audio_codes_target_unstacked = audio_codes[:, :, self.frame_stacking_factor :]
audio_codes_lens_input_unstacked = audio_codes_lens - 1 # don't count EOS for input
audio_codes_lens_target_unstacked = audio_codes_lens - self.frame_stacking_factor # don't count BOS for target
audio_codes_lens_input = torch.floor(audio_codes_lens_input_unstacked / self.frame_stacking_factor).long()
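# With the BOS repeat above, input and target are shifted by exactly one stacked frame: the input drops the
# last stacked frame and the target drops the BOS-only first stacked frame.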
audio_codes_embedded_all = self.embed_audio_tokens(
audio_codes
) # (B, T, E) # Computing this to be used in the alignment encoder
audio_codes_embedded = audio_codes_embedded_all[
:, :-1, :
] # (B, T', E) Input to the decoder; this is already in the frame-stacked domain, hence the -1 (not `frame_stacking_factor`)
audio_codes_mask = get_mask_from_lengths(audio_codes_lens_input)
use_cfg = (self.cfg_unconditional_prob > 0.0) and (mode == "train") and (context_tensors['cond'] is not None)
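# Classifier-free-guidance training: with probability cfg_unconditional_prob the conditioning below is replaced
# by zeroed dummies, so the model also learns an unconditional mode that can be mixed in via CFG at inference.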
if use_cfg and torch.rand(1).item() < self.cfg_unconditional_prob:
cond, cond_mask, additional_decoder_input, additional_decoder_mask, attn_prior = (
self.prepare_dummy_cond_for_cfg(
context_tensors['cond'],
context_tensors['cond_mask'],
context_tensors['additional_decoder_input'],
context_tensors['additional_decoder_mask'],
)
)
disable_alignment_loss = True
else:
cond = context_tensors['cond']
cond_mask = context_tensors['cond_mask']
additional_decoder_input = context_tensors['additional_decoder_input']
additional_decoder_mask = context_tensors['additional_decoder_mask']
attn_prior = context_tensors['attn_prior']
if mode == "train" and self.decoder_input_dropout_prob > 0.0 and torch.rand(1).item() < 0.5:
# For some batches (half of them), replace decoder_input_dropout_prob of the timesteps with random tokens
max_codebook_val = self.dec_random_input_max
# @pneekhara: Keeping dec_random_input_max configurable since num_all_tokens_per_codebook usually has padding tokens
# which can cause errors when doing codes_to_audio for audio_codes_input. We are not currently calling codes_to_audio on
# audio_codes_input so should not matter if we don't supply dec_random_input_max.
random_audio_tokens = torch.randint(
0, max_codebook_val, audio_codes_input_unstacked.size(), device=audio_codes_input_unstacked.device
)
random_audio_tokens = random_audio_tokens * audio_codes_mask.unsqueeze(1)
dec_dropout_mask = (
torch.rand((1, 1, audio_codes_input_unstacked.size(2)), device=audio_codes_input_unstacked.device)
> self.decoder_input_dropout_prob
)
# dec_dropout_mask is True for timesteps to be kept
audio_codes_input_unstacked = audio_codes_input_unstacked * dec_dropout_mask + random_audio_tokens * (
~dec_dropout_mask
)
audio_codes_embedded = self.embed_audio_tokens(audio_codes_input_unstacked) # (B, T', E)
if context_tensors['additional_decoder_input'] is not None:
dec_input_embedded = torch.cat([additional_decoder_input, audio_codes_embedded], dim=1)
dec_input_mask = torch.cat([additional_decoder_mask, audio_codes_mask], dim=1)
else:
dec_input_embedded = audio_codes_embedded
dec_input_mask = audio_codes_mask
aligner_encoder_loss = None
aligner_attn_soft = None
aligner_attn_hard = None
if self.use_alignment_encoder and not disable_alignment_loss:
aligner_prior = None
if self.use_prior_for_aligner:
aligner_prior = context_tensors['beta_binomial_attn_prior']
# Passing target audio embeddings to the alignment encoder
if self.global_step < self.aligner_encoder_train_steps:
aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder(
queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T'
keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T
mask=~context_tensors['text_mask'].unsqueeze(-1),
attn_prior=aligner_prior,
)
aligner_encoder_loss = self.alignment_encoder_loss(
attn_logprob=aligner_attn_logprobs,
in_lens=context_tensors['text_lens'],
out_lens=audio_codes_lens_input,
)
else:
with torch.no_grad():
# Just get the attention matrix without computing the loss or gradients
aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder(
queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T'
keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T
mask=~context_tensors['text_mask'].unsqueeze(-1),
attn_prior=aligner_prior,
)
with torch.no_grad():
aligner_attn_hard = self.get_binarized_prior_matrix(
aligner_attn_soft, audio_codes_lens_input, context_tensors['text_lens']
)
if (self.global_step > self.binarize_prior_after_step) and context_tensors['prior_used']:
attn_prior = self.replace_beta_binomial_prior_with_binarized(attn_prior, aligner_attn_hard)
logits, attn_info, dec_out = self.forward(
dec_input_embedded=dec_input_embedded,
dec_input_mask=dec_input_mask,
cond=cond,
cond_mask=cond_mask,
attn_prior=attn_prior,
multi_encoder_mapping=context_tensors['multi_encoder_mapping'],
)
# logits: (B, T', num_codebooks * num_tokens_per_codebook)
# dec_out: (B, T', E)
dec_context_size = context_tensors['dec_context_size']
logits = logits[:, dec_context_size:, :] # Remove the context audio embeddings from the logits
# Codebook loss (parallel)
codebook_loss, loss_mask = self.compute_loss(
logits,
audio_codes_target_unstacked,
audio_codes_lens_target_unstacked,
frame_stacking_factor=self.frame_stacking_factor,
)
# Alignment loss
alignment_loss = None
if self.alignment_loss_scale > 0.0 and not disable_alignment_loss:
text_lens = context_tensors['text_lens']
cross_attention_scores = [
attn['cross_attn_probabilities'][1]
for layer_idx, attn in enumerate(attn_info)
if layer_idx in self.ctc_prior_layer_ids
]
alignment_loss = self.compute_alignment_loss(
cross_attention_scores, text_lens, audio_codes_lens_input, dec_context_size
)
loss = self.codebook_loss_scale * codebook_loss + alignment_loss
else:
loss = self.codebook_loss_scale * codebook_loss
# Local Transformer loss
local_transformer_loss = None
local_transformer_logits = None
if self.local_transformer_type != LocalTransformerType.NO_LT:
if self.local_transformer_type == LocalTransformerType.MASKGIT:
# Maskgit
# randomly replace some positions with MASK_TOKEN
audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target_unstacked)
# TODO @rfejgin: the very last position might be padding but the local transformer might look at it as part of
# a pair where the first position is valid. Is this an issue?
local_transformer_logits = self.compute_local_transformer_logits(
dec_out[:, dec_context_size:, :], audio_codes_masked, targets_offset_by_one=True
)
local_transformer_loss, _ = self.compute_loss(
local_transformer_logits,
audio_codes_target_unstacked,
audio_codes_lens_target_unstacked,
mask_tokens_mask,
frame_stacking_factor=self.frame_stacking_factor,
)
else:
# Autoregressive
assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type"
local_transformer_logits = self.compute_local_transformer_logits(
dec_out[:, dec_context_size:, :], audio_codes_target_unstacked, targets_offset_by_one=False
)
local_transformer_loss, _ = self.compute_loss(
local_transformer_logits,
audio_codes_target_unstacked,
audio_codes_lens_target_unstacked,
None,
frame_stacking_factor=self.frame_stacking_factor,
)
loss = loss + self.local_transformer_loss_scale * local_transformer_loss
if aligner_encoder_loss is not None:
loss = loss + aligner_encoder_loss
return {
'logits': logits,
'attn_info': attn_info,
'loss': loss,
'codebook_loss': codebook_loss,
'local_transformer_loss': local_transformer_loss,
'local_transformer_logits': local_transformer_logits,
'loss_mask': loss_mask,
'alignment_loss': alignment_loss,
'aligner_encoder_loss': aligner_encoder_loss,
'audio_codes_target': audio_codes_target_unstacked,
'audio_codes_lens_target': audio_codes_lens_target_unstacked,
'text': context_tensors['text'],
'text_lens': context_tensors['text_lens'],
'context_audio_codes': context_tensors['context_audio_codes'],
'context_audio_codes_lens': context_tensors['context_audio_codes_lens'],
'dec_context_size': dec_context_size,
'aligner_attn_soft': aligner_attn_soft,
'aligner_attn_hard': aligner_attn_hard,
}
def training_step(self, batch, batch_idx):
batch_output = self.process_batch(batch)
loss = batch_output['loss']
codebook_loss = batch_output['codebook_loss']
self.log('train/codebook_loss', codebook_loss, prog_bar=True, sync_dist=True)
if self.cfg_unconditional_prob == 0.0:
# Only log alignment loss when not using cfg to avoid sync issues when
# alignment loss is None on some ranks
alignment_loss = batch_output['alignment_loss']
if alignment_loss is not None:
self.log('train/alignment_loss', alignment_loss, prog_bar=True, sync_dist=True)
self.log('train/loss', loss, prog_bar=True, sync_dist=True)
local_transformer_loss = batch_output['local_transformer_loss']
if local_transformer_loss is not None:
self.log('train/local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True)
# Log batch info
batch_size, text_token_max_len = batch["text"].shape
text_token_total_num = batch["text_lens"].sum()
batch_info_dict = {
"train/batch_size": batch_size,
"train/text_token_max_len": text_token_max_len,
"train/text_token_total_num_in_batch": text_token_total_num.item(),
"train/text_token_pad_ratio_percent_in_batch": 100
* (1 - text_token_total_num / (batch_size * text_token_max_len)),
}
if "audio_codes" in batch:
audio_codes_max_len = batch["audio_codes"].shape[-1]
audio_codes_total_num = batch["audio_codes_lens"].sum()
batch_info_dict.update(
{
"train/audio_codes_max_len": audio_codes_max_len,
"train/audio_codes_total_num_in_batch": audio_codes_total_num.item(),
"train/audio_codes_pad_ratio_percent_in_batch": 100
* (1 - audio_codes_total_num / (batch_size * audio_codes_max_len)),
}
)
else:
audio_samples_max_len = batch["audio"].shape[-1]
audio_samples_total_num = batch["audio_lens"].sum()
batch_info_dict.update(
{
"train/audio_samples_max_len": audio_samples_max_len,
"train/audio_samples_total_num_in_batch": audio_samples_total_num.item(),
"train/audio_samples_pad_ratio_percent_in_batch": 100
* (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)),
}
)
self.log_dict(batch_info_dict, on_step=True)
return loss
def validation_step(self, batch, batch_idx):
batch_output = self.process_batch(batch, mode="val")
# self.process_batch returns a dict. We currently only log "logits" which come from the parallel prediction
# head. If we use local_transformer, then the local_transformer returns "local_transformer_logits"
loss = batch_output['loss']
codebook_loss = batch_output['codebook_loss']
alignment_loss = batch_output['alignment_loss']
aligner_encoder_loss = batch_output['aligner_encoder_loss']
logits = batch_output['logits']
audio_codes_target = batch_output['audio_codes_target']
audio_codes_lens_target = batch_output['audio_codes_lens_target']
context_audio_codes = batch_output['context_audio_codes']
context_audio_codes_lens = batch_output['context_audio_codes_lens']
attn_info = batch_output['attn_info']
text_lens = batch_output['text_lens']
dec_context_size = batch_output['dec_context_size']
if alignment_loss is None:
alignment_loss = torch.tensor(0.0, device=loss.device)
if aligner_encoder_loss is None:
aligner_encoder_loss = torch.tensor(0.0, device=loss.device)
if batch_idx == 0 and self.global_rank == 0:
# Prepare dictionary for aggregated wandb logging
wandb_log_dict = {}
# Get audio data for logging
wandb_log_dict.update(
self.log_val_audio_example(
logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens
)
)
# Get attention image data for logging
if len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1:
# cross_attn_probabilities only returned when not using flash attention
cross_attention_probs = [
attn['cross_attn_probabilities'][0]
for layer_idx, attn in enumerate(attn_info)
if layer_idx in self.ctc_prior_layer_ids
]
wandb_log_dict.update(
self.log_attention_probs(
cross_attention_probs,
audio_codes_lens_target,
text_lens,
prefix="val",
dec_context_size=dec_context_size,
)
)
for layer_idx in self.transcript_decoder_layers:
cross_attention_probs = [attn_info[layer_idx]['cross_attn_probabilities'][0]]
wandb_log_dict.update(
self.log_attention_probs(
cross_attention_probs,
audio_codes_lens_target,
text_lens,
prefix=f"val/layer_{layer_idx}",
dec_context_size=dec_context_size,
)
)
if batch_output['aligner_attn_soft'] is not None:
wandb_log_dict.update(
self.log_attention_probs(
[batch_output['aligner_attn_soft']],
audio_codes_lens_target,
text_lens,
prefix="val/aligner_encoder_attn",
)
)
if batch_output['aligner_attn_hard'] is not None:
wandb_log_dict.update(
self.log_attention_probs(
[batch_output['aligner_attn_hard'].unsqueeze(1)],
audio_codes_lens_target,
text_lens,
prefix="val/aligner_encoder_attn_hard",
)
)
# Perform single wandb log call if wandb is active and there is data
for logger in self.loggers:
if isinstance(logger, WandbLogger) and wandb_log_dict:
logger.experiment.log(wandb_log_dict)
local_transformer_loss = batch_output['local_transformer_loss']
val_output = {
'val_loss': loss,
'val_codebook_loss': codebook_loss,
'val_alignment_loss': alignment_loss,
'val_local_transformer_loss': local_transformer_loss,
'val_aligner_encoder_loss': aligner_encoder_loss,
}
self.validation_step_outputs.append(val_output)
return val_output
def get_cross_attention_scores(self, attn_probs, filter_layers=None):
"""
Returns the cross attention probabilities for the last audio timestep
"""
mean_cross_attn_scores = []
all_heads_cross_attn_scores = []
for lidx, layerwise_attn_prob in enumerate(attn_probs):
if (filter_layers is not None and lidx not in filter_layers) or (
lidx not in self.transcript_decoder_layers
):
continue
cross_attn_prob = layerwise_attn_prob['cross_attn_probabilities'][
0
] # B, H, audio_timesteps, text_timesteps
mean_cross_attn_scores.append(cross_attn_prob.mean(dim=1)) # B, audio_timesteps, text_timesteps
for head_idx in range(cross_attn_prob.size(1)):
all_heads_cross_attn_scores.append(cross_attn_prob[:, head_idx, -1, :]) # B, text_timesteps
mean_cross_attn_scores = torch.stack(mean_cross_attn_scores, dim=1) # B, L, audio_timesteps, text_timesteps
mean_cross_attn_scores = mean_cross_attn_scores.mean(dim=1) # B, audio_timesteps, text_timesteps
last_audio_timestep_scores = mean_cross_attn_scores[:, -1, :] # B, text_timesteps
return last_audio_timestep_scores, all_heads_cross_attn_scores
def get_most_attended_text_timestep(
self,
alignment_attention_scores,
last_attended_timesteps,
text_lens,
lookahead_window_size,
attended_timestep_counter,
batch_size,
):
"""
Returns the most attended timestep for each batch item
"""
text_time_step_attended = []
for bidx in range(batch_size):
last_attended_timestep = last_attended_timesteps[-1][bidx]
if attended_timestep_counter[bidx].get(last_attended_timestep, 0) >= 8:
# This is probably an attention sink! Move to the next timestep
last_attended_timestep += 1
window_size = lookahead_window_size
window_end = min(last_attended_timestep + window_size, text_lens[bidx] - 3) # Ignore the last 3 timesteps
item_attention_scores = alignment_attention_scores[bidx, last_attended_timestep:window_end]
if item_attention_scores.size(0) == 0:
# This means the sentence has ended
attended_timestep = text_lens[bidx].item() - 1
else:
attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep
text_time_step_attended.append(attended_timestep)
attended_timestep_counter[bidx][attended_timestep] = (
attended_timestep_counter[bidx].get(attended_timestep, 0) + 1
)
return text_time_step_attended, attended_timestep_counter
def construct_inference_prior(
self,
prior_epsilon,
cross_attention_scores,
text_lens,
text_time_step_attended,
attended_timestep_counter,
unfinished_texts,
finished_texts_counter,
end_indices,
lookahead_window_size,
batch_size,
):
# Attn prior for the next timestep
_attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + prior_epsilon
_attn_prior = _attn_prior.to(cross_attention_scores.device)
for bidx in range(cross_attention_scores.shape[0]):
if bidx < batch_size:
_text_len = text_lens[bidx]
if text_lens[bidx] <= 5:
# Very short sentences, No Prior
_attn_prior[bidx, 0, :] = 1.0
else:
_attn_prior[bidx, 0, max(1, text_time_step_attended[bidx] - 1)] = (
1.0 # Slight exposure to history for better pronunciation. Not very important.
)
_attn_prior[bidx, 0, text_time_step_attended[bidx]] = (
1.0 # Slightly bias to continue moving forward. Not very important.
)
for ind in range(1, lookahead_window_size + 1):
_attn_prior[bidx, 0, min(text_time_step_attended[bidx] + ind, _text_len - 1)] = 1.0
# Penalize timesteps that have been attended to more than 10 times
for _timestep in attended_timestep_counter[bidx]:
if attended_timestep_counter[bidx][_timestep] >= 10:
# This means the timestep has been attended to more than 10 times (To avoid getting stuck)
_attn_prior[bidx, 0, : _timestep + 1] = prior_epsilon
unfinished_texts[bidx] = False
if text_time_step_attended[bidx] < text_lens[bidx] - 3:
# This means the sentence has not ended
if bidx not in end_indices:
unfinished_texts[bidx] = True
if text_time_step_attended[bidx] >= text_lens[bidx] - 2 or bidx in end_indices:
if bidx not in finished_texts_counter:
finished_texts_counter[bidx] = 0
for bidx in finished_texts_counter:
finished_texts_counter[bidx] += 1
if finished_texts_counter[bidx] > 5:
# This means we have been within the text EOS window for at least 5 timesteps
# We should allow EOS to be predicted now.
unfinished_texts[bidx] = False
return _attn_prior, unfinished_texts, finished_texts_counter
def get_inference_attention_plots(
self,
cross_attention_scores_all_timesteps,
all_heads_cross_attn_scores_all_timesteps,
text_lens,
predicted_codes_lens,
batch_size,
compute_all_heads_attn_maps,
last_attended_timestep,
):
last_attended_timestep = np.array(last_attended_timestep).T
cross_attention_scores_all_timesteps = torch.stack(
cross_attention_scores_all_timesteps, dim=2
) # B, text_timesteps, T'
headwise_cross_attention_scores_all_timesteps = []
for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])):
head_cross_attention_all_timesteps = torch.stack(
[x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2
) # B, text_timesteps, T'
headwise_cross_attention_scores_all_timesteps.append(head_cross_attention_all_timesteps)
cross_attention_maps = []
headwise_cross_attention_maps = []
for bidx in range(batch_size):
item_cross_attention_scores = cross_attention_scores_all_timesteps[
bidx, : text_lens[bidx], : predicted_codes_lens[bidx]
]
cross_attn_np = plot_alignment_to_numpy(
item_cross_attention_scores.cpu().numpy(),
attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]],
)
cross_attention_maps.append(cross_attn_np)
item_all_head_cross_attn_maps = []
if compute_all_heads_attn_maps:
for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])):
item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][
bidx, : text_lens[bidx], : predicted_codes_lens[bidx]
]
headwise_cross_attn_np = plot_alignment_to_numpy(
item_headwise_cross_attention_scores.cpu().numpy(),
attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]],
)
item_all_head_cross_attn_maps.append(headwise_cross_attn_np)
headwise_cross_attention_maps.append(item_all_head_cross_attn_maps)
return cross_attention_maps, headwise_cross_attention_maps
def find_eos_frame_index(self, codes, eos_detection_method) -> Union[int, float]:
"""
Checks for EOS in the predicted codes. Returns the index of the first frame within the frame stack
that contains EOS according to `eos_detection_method`, or `float('inf')` if no EOS is found.
Args:
codes: (num_codebooks, frame_stacking_factor)
eos_detection_method: EOS detection method (determines whether any, all, or only the zeroth codebook must contain EOS)
Returns:
index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found
"""
eos_mask = codes == self.audio_eos_id # (codebooks, frame_stacking_factor)
detection_type = EOSDetectionMethod.detection_type(eos_detection_method)
if detection_type == "any":
eos_per_frame = eos_mask.any(
dim=0
) # (frame_stacking_factor,) - True if any codebook has EOS in this frame
elif detection_type == "all":
eos_per_frame = eos_mask.all(
dim=0
) # (frame_stacking_factor,) - True if all codebooks have EOS in this frame
elif detection_type == "zero_cb":
eos_per_frame = eos_mask[:1, :].any(
dim=0
) # (frame_stacking_factor,) - True if zeroth codebook has EOS in this frame
else:
raise ValueError(f"Invalid EOS detection method: {eos_detection_method}")
# find first frame with EOS
if eos_per_frame.any():
# return index of the first frame with EOS
return eos_per_frame.nonzero()[0].item()
return float('inf')
def detect_eos(self, audio_codes_multinomial, audio_codes_argmax, eos_detection_method) -> Union[int, float]:
"""
Detects EOS in the predicted codes. Returns the index of the first frame within the frame stack
that triggers EOS detection, or `float('inf')` if no EOS is found.
Args:
audio_codes_multinomial: (num_codebooks, frame_stacking_factor) - Multinomial samples
audio_codes_argmax: (num_codebooks, frame_stacking_factor) - Argmax samples
eos_detection_method: EOS detection method
Returns:
index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found
"""
sampling_type = EOSDetectionMethod.sampling_type(eos_detection_method)
if sampling_type == "argmax":
return self.find_eos_frame_index(audio_codes_argmax, eos_detection_method)
elif sampling_type == "argmax_or_multinomial":
argmax_eos_frame = self.find_eos_frame_index(audio_codes_argmax, eos_detection_method)
multinomial_eos_frame = self.find_eos_frame_index(audio_codes_multinomial, eos_detection_method)
return min(argmax_eos_frame, multinomial_eos_frame)
else:
raise ValueError(f"Invalid EOS detection method: {eos_detection_method}")
def infer_batch(
self,
batch,
max_decoder_steps=500,
temperature=0.7,
topk=80,
use_cfg=False,
cfg_scale=1.0,
return_cross_attn_probs=False,
apply_attention_prior=False,
prior_epsilon=1e-5,
lookahead_window_size=10,
estimate_alignment_from_layers=None,
apply_prior_to_layers=None,
start_prior_after_n_audio_steps=10,
compute_all_heads_attn_maps=False,
use_local_transformer_for_inference=False,
use_LT_kv_cache=True,
maskgit_n_steps=3,
maskgit_noise_scale=0.0,
maskgit_fixed_schedule=None,
maskgit_dynamic_cfg_scale=False,
maskgit_sampling_type=None,
ignore_finished_sentence_tracking=False,
eos_detection_method="argmax_or_multinomial_any",
# Setting this greater than 0 prevents rare cases of first-frame termination. Any number between 1 and 4 should work, but 4
# lines up with the codec's minimum frame requirement.
min_generated_frames=4,
):
eos_detection_method = EOSDetectionMethod(eos_detection_method)
with torch.no_grad():
start_time = time.time()
self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference)
context_tensors = self.prepare_context_tensors(batch)
text = context_tensors['text']
audio_codes_bos = torch.full(
(text.size(0), self.num_audio_codebooks, self.frame_stacking_factor),
self.audio_bos_id,
device=text.device,
).long()
audio_codes_lens = torch.full(
(text.size(0),), 1, device=text.device
).long() # intentionally 1 rather than self.frame_stacking_factor since this is in stacked form
audio_codes_input = audio_codes_bos
audio_codes_mask = get_mask_from_lengths(audio_codes_lens)
all_predictions = []
end_indices = {}
if use_cfg:
dummy_cond, dummy_cond_mask, dummy_additional_decoder_input, dummy_addition_dec_mask, _ = (
self.prepare_dummy_cond_for_cfg(
context_tensors['cond'],
context_tensors['cond_mask'],
context_tensors['additional_decoder_input'],
context_tensors['additional_decoder_mask'],
)
)
cross_attention_scores_all_timesteps = []
all_heads_cross_attn_scores_all_timesteps = []
_attn_prior = None
unfinished_texts = {}
finished_texts_counter = {}
attended_timestep_counter = [{} for _ in range(text.size(0))]
last_attended_timesteps = [
[1 for _ in range(text.size(0))]
] # Maintain a list of attended timesteps as we predict audio for each batch item
time_to_first_prediction = 0.0
for idx in range(max_decoder_steps // self.frame_stacking_factor):
if idx == 1:
time_to_first_prediction = time.time() - start_time
if idx % 20 == 0:
print(f"Decoding timestep {idx}")
audio_codes_embedded = self.embed_audio_tokens(audio_codes_input)
if context_tensors['additional_decoder_input'] is not None:
_audio_codes_embedded = torch.cat(
[context_tensors['additional_decoder_input'], audio_codes_embedded], dim=1
)
_audio_codes_mask = torch.cat(
[context_tensors['additional_decoder_mask'], audio_codes_mask], dim=1
)
else:
_audio_codes_embedded = audio_codes_embedded
_audio_codes_mask = audio_codes_mask
if apply_prior_to_layers is not None:
attn_prior = [None for _ in range(self.decoder.n_layers)]
for layer_idx in apply_prior_to_layers:
attn_prior[layer_idx] = _attn_prior
else:
attn_prior = _attn_prior
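# The multi-encoder variant expects one prior per encoder; apply the prior to the first encoder's
# cross-attention only and pass None for the second.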
if self.model_type == 'multi_encoder_context_tts':
attn_prior = [attn_prior, None]
if use_cfg:
batch_size = audio_codes_embedded.size(0)
if isinstance(context_tensors['cond'], list):
cfg_cond = [
torch.cat([cond_item, dummy_cond_item], dim=0)
for cond_item, dummy_cond_item in zip(context_tensors['cond'], dummy_cond)
]
cfg_cond_mask = [
torch.cat([cond_mask_item, dummy_cond_mask_item], dim=0)
for cond_mask_item, dummy_cond_mask_item in zip(
context_tensors['cond_mask'], dummy_cond_mask
)
]
else:
cfg_cond = torch.cat([context_tensors['cond'], dummy_cond], dim=0)
cfg_cond_mask = torch.cat([context_tensors['cond_mask'], dummy_cond_mask], dim=0)
cfg_audio_codes_embedded = torch.cat([_audio_codes_embedded, _audio_codes_embedded], dim=0)
cfg_audio_codes_mask = torch.cat([_audio_codes_mask, _audio_codes_mask], dim=0)
if dummy_additional_decoder_input is not None:
cfg_audio_codes_embedded[batch_size:, : dummy_additional_decoder_input.size(1)] = (
dummy_additional_decoder_input
)
cfg_audio_codes_mask[batch_size:, : dummy_additional_decoder_input.size(1)] = (
dummy_addition_dec_mask
)
# print(f"step {idx}")
# print(f"use_cfg {use_cfg}")
# print(f"shape {cfg_audio_codes_embedded.shape}")
# print(f"use kv cahce? {self.use_kv_cache_for_inference}")
combined_logits, attn_probs, dec_out = self.forward(
dec_input_embedded=cfg_audio_codes_embedded,
dec_input_mask=cfg_audio_codes_mask,
cond=cfg_cond,
cond_mask=cfg_cond_mask,
attn_prior=attn_prior,
multi_encoder_mapping=context_tensors['multi_encoder_mapping'],
)
cond_logits = combined_logits[:batch_size]
uncond_logits = combined_logits[batch_size:]
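# Classifier-free guidance blend: cfg_scale == 1 reproduces the conditional logits, while larger values
# extrapolate away from the unconditional prediction to strengthen conditioning.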
all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits
else:
batch_size = audio_codes_embedded.size(0)
all_code_logits, attn_probs, dec_out = self.forward(
dec_input_embedded=_audio_codes_embedded,
dec_input_mask=_audio_codes_mask,
cond=context_tensors['cond'],
cond_mask=context_tensors['cond_mask'],
attn_prior=attn_prior,
multi_encoder_mapping=context_tensors['multi_encoder_mapping'],
)
if return_cross_attn_probs or apply_attention_prior:
cross_attention_scores, all_heads_cross_attn_scores = self.get_cross_attention_scores(
attn_probs
) # B, text_timesteps
alignment_attention_scores = cross_attention_scores
if estimate_alignment_from_layers is not None:
alignment_attention_scores, _ = self.get_cross_attention_scores(
attn_probs, filter_layers=estimate_alignment_from_layers
) # B, text_timesteps
cross_attention_scores_all_timesteps.append(cross_attention_scores)
all_heads_cross_attn_scores_all_timesteps.append(all_heads_cross_attn_scores)
if apply_attention_prior and idx >= start_prior_after_n_audio_steps:
text_time_step_attended, attended_timestep_counter = self.get_most_attended_text_timestep(
alignment_attention_scores=alignment_attention_scores,
last_attended_timesteps=last_attended_timesteps,
text_lens=context_tensors['text_lens'],
lookahead_window_size=lookahead_window_size,
attended_timestep_counter=attended_timestep_counter,
batch_size=batch_size,
)
last_attended_timesteps.append(text_time_step_attended)
_attn_prior, unfinished_texts, finished_texts_counter = self.construct_inference_prior(
prior_epsilon=prior_epsilon,
cross_attention_scores=cross_attention_scores,
text_lens=context_tensors['text_lens'],
text_time_step_attended=text_time_step_attended,
attended_timestep_counter=attended_timestep_counter,
unfinished_texts=unfinished_texts,
finished_texts_counter=finished_texts_counter,
end_indices=end_indices,
lookahead_window_size=lookahead_window_size,
batch_size=batch_size,
)
if ignore_finished_sentence_tracking:
finished_items = {}
unfinished_items = {}
else:
finished_items = {
k: v for k, v in finished_texts_counter.items() if v >= 20
} # Items that have been close to the end for at least 20 timesteps
unfinished_items = {k: v for k, v in unfinished_texts.items() if v}
# Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor)
# This guards against rare cases of termination right at the start of generation.
forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames
all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook)
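# Two local-transformer decoding paths follow: autoregressive sampling over the local transformer, and a
# MaskGIT-style path that refines the codes over `maskgit_n_steps` iterative passes.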
if use_local_transformer_for_inference:
if self.local_transformer_type == LocalTransformerType.AR:
# Autoregressive sampling with local transformer
audio_codes_next = self.local_transformer_sample_autoregressive(
dec_output=dec_out[:, -1, :],
temperature=temperature,
topk=topk,
unfinished_items=unfinished_items,
finished_items=finished_items,
use_cfg=use_cfg,
cfg_scale=cfg_scale,
use_kv_cache=use_LT_kv_cache,
forbid_audio_eos=forbid_audio_eos,
)
elif self.local_transformer_type == LocalTransformerType.MASKGIT:
audio_codes_next = self.local_transformer_sample_maskgit(
dec_output=dec_out[:, -1, :],
temperature=temperature,
topk=topk,
unfinished_items=unfinished_items,
finished_items=finished_items,
use_cfg=use_cfg,
cfg_scale=cfg_scale,
n_steps=maskgit_n_steps,
noise_scale=maskgit_noise_scale,
fixed_schedule=maskgit_fixed_schedule,
dynamic_cfg_scale=maskgit_dynamic_cfg_scale,
sampling_type=maskgit_sampling_type,
forbid_audio_eos=forbid_audio_eos,
)
else:
raise ValueError(
f"Local transformer inference requested by but local transformer type is {self.local_transformer_type}"
)
else:
# Parallel sampling from all codebooks
audio_codes_next = self.sample_codes_from_logits(
all_code_logits_t,
temperature=temperature,
topk=topk,
unfinished_items=unfinished_items,
finished_items=finished_items,
forbid_audio_eos=forbid_audio_eos,
) # (B, num_codebooks, frame_stacking_factor)
all_codes_next_argmax = self.sample_codes_from_logits(
all_code_logits_t,
temperature=0.01,
topk=1,
unfinished_items=unfinished_items,
finished_items=finished_items,
forbid_audio_eos=forbid_audio_eos,
) # (B, num_codebooks, frame_stacking_factor)
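# The greedy (argmax) codes are drawn only to support EOS detection below; the sampled codes are what
# actually get appended to the generated sequence.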
for item_idx in range(all_codes_next_argmax.size(0)):
if item_idx not in end_indices:
end_frame_index = self.detect_eos(
audio_codes_next[item_idx], all_codes_next_argmax[item_idx], eos_detection_method
)
if end_frame_index != float('inf'):
global_index = idx * self.frame_stacking_factor + end_frame_index
end_indices[item_idx] = global_index
print(f"End detected for item {item_idx} at decoder timestep: {idx}")
all_predictions.append(audio_codes_next)
audio_codes_input = torch.cat([audio_codes_input, audio_codes_next], dim=-1) # (B, C, T')
audio_codes_lens = audio_codes_lens + 1 # already in stacked form
audio_codes_mask = get_mask_from_lengths(audio_codes_lens)
if len(end_indices) == text.size(0) and len(all_predictions) >= 4:
# The codec needs at least 4 frames to decode properly
print("All ends reached")
break
tts_generation_time = time.time() - start_time
tts_generation_time_per_frame = tts_generation_time / (len(all_predictions) * self.frame_stacking_factor)
# Concatenate the list of predictions along the time dimension. Note that when frame stacking is on,
# this also undoes the stacking.
predicted_codes = torch.cat(all_predictions, dim=-1) # (B, num_codebooks, T')
predicted_lens = [
end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0))
] # Fall back to max_decoder_steps for items where no EOS was detected
predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long()
predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens)
end_time = time.time()
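# Approximate total audio duration across the batch (max length times batch size, converted from samples
# to seconds); RTF is then audio seconds generated per wall-clock second, so values > 1 mean faster than real time.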
total_audio_duration_generated = (
predicted_audio_lens.max().item() * predicted_audio_lens.shape[0]
) / self.sample_rate
rtf = total_audio_duration_generated / (end_time - start_time)
rtf_metrics = {
'rtf': rtf,
'time_to_first_prediction': time_to_first_prediction,
'tts_generation_time': tts_generation_time,
'max_frames_generated': len(all_predictions),
'tts_generation_time_per_frame': tts_generation_time_per_frame,
'batch_size': text.size(0),
}
torch.cuda.empty_cache()
cross_attention_maps = None
headwise_cross_attention_maps = None
if return_cross_attn_probs:
cross_attention_maps, headwise_cross_attention_maps = self.get_inference_attention_plots(
cross_attention_scores_all_timesteps,
all_heads_cross_attn_scores_all_timesteps,
context_tensors['text_lens'],
predicted_codes_lens,
text.size(0),
compute_all_heads_attn_maps,
last_attended_timesteps,
)
return InferBatchOutput(
predicted_audio=predicted_audio,
predicted_audio_lens=predicted_audio_lens,
predicted_codes=predicted_codes,
predicted_codes_lens=predicted_codes_lens,
rtf_metrics=rtf_metrics,
cross_attention_maps=cross_attention_maps,
headwise_cross_attention_maps=headwise_cross_attention_maps,
)
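# A minimal usage sketch for `infer_batch` (illustrative only; the checkpoint path, restore API, and the
# way the batch is obtained here are assumptions, not part of this module):
#
#     model = ModelClass.restore_from("checkpoint.nemo").eval().cuda()  # hypothetical restore call
#     batch = next(iter(model._test_dl))                                # any batch with the expected keys
#     out = model.infer_batch(batch, max_decoder_steps=500, temperature=0.7, topk=80,
#                             use_cfg=True, cfg_scale=2.5)
#     audio, audio_lens = out.predicted_audio, out.predicted_audio_lens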
def test_step(self, batch, batch_idx):
with torch.no_grad():
test_dl_batch_size = self._test_dl.batch_size
temperature = self.cfg.get('inference_temperature', 0.7)
topk = self.cfg.get('inference_topk', 80)
use_cfg = self.cfg.get('inference_use_cfg', False)
cfg_scale = self.cfg.get('inference_cfg_scale', 1.0)
output = self.infer_batch(
batch,
max_decoder_steps=self.cfg.get('max_decoder_steps', 500),
temperature=temperature,
topk=topk,
use_cfg=use_cfg,
cfg_scale=cfg_scale,
)
predicted_audio = output.predicted_audio
predicted_audio_lens = output.predicted_audio_lens
for logger in self.loggers:
is_wandb = isinstance(logger, WandbLogger)
is_tb = isinstance(logger, TensorBoardLogger)
if not is_wandb and not is_tb:
raise ValueError(
"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
)
for idx in range(predicted_audio.size(0)):
predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]]
item_idx = batch_idx * test_dl_batch_size + idx
if is_wandb:
log_dict = {
"test/predicted_audio": wandb.Audio(
predicted_audio_np, sample_rate=self.sample_rate, caption="Predicted Audio"
),
}
logger.experiment.log(log_dict, step=item_idx)
if is_tb:
logger.experiment.add_audio(
'test/predicted_audio',
predicted_audio_np,
global_step=item_idx,
sample_rate=self.sample_rate,
)
# Save the predicted audio
log_dir = logger.log_dir
audio_dir = os.path.join(log_dir, 'audios')
if not os.path.exists(audio_dir):
os.makedirs(audio_dir)
audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav')
sf.write(audio_path, predicted_audio_np, self.sample_rate)
def on_validation_epoch_end(self):
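# Average each logged metric over every validation step gathered in validation_step_outputs.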
collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean()
val_loss = collect("val_loss")
val_codebook_loss = collect("val_codebook_loss")
val_alignment_loss = collect("val_alignment_loss")
val_aligner_encoder_loss = collect("val_aligner_encoder_loss")
# log val_loss in the same group as the other val metrics.
self.log("val/loss", val_loss, prog_bar=True, sync_dist=True)
# ensure val_loss is available for epoch-level checkpointing and filename generation without cluttering wandb logs.
self.log(
"val_loss",
val_loss,
prog_bar=False,
sync_dist=True,
on_step=False,
on_epoch=True,
logger=False,
enable_graph=False,
)
self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True)
self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True)
self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True)
if self.local_transformer_type != LocalTransformerType.NO_LT:
val_local_transformer_loss = collect("val_local_transformer_loss")
self.log("val/local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True)
self.validation_step_outputs.clear() # free memory
def get_dataset(self, dataset_cfg, dataset_type):
dataset = instantiate(
dataset_cfg.dataset,
sample_rate=self.sample_rate,
bos_id=self.bos_id,
eos_id=self.eos_id,
audio_bos_id=self.audio_bos_id,
audio_eos_id=self.audio_eos_id,
context_audio_bos_id=self.context_audio_bos_id,
context_audio_eos_id=self.context_audio_eos_id,
num_audio_codebooks=self.data_num_audio_codebooks,
codec_model_samples_per_frame=self.codec_model_samples_per_frame,
prior_scaling_factor=self.cfg.prior_scaling_factor,
load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
dataset_type=dataset_type, # 'train' or 'test'; used to set phone prob to 1.0 for the test dataset (worker_init_fn)
use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder,
text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name,
pad_context_text_to_max_duration=self.pad_context_text_to_max_duration,
context_duration_min=self.cfg.context_duration_min,
context_duration_max=self.cfg.context_duration_max,
text_context_remapping=self.text_context_remapping,
text_context_remapping_prob=self.text_context_remapping_prob,
)
dataset.load_16khz_audio = False
dataset.tokenizer_config = (
self.cfg.text_tokenizers
) # This will be used in worker_init_fn for instantiating tokenizer
return dataset
def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader:
# TODO @xueyang: disambiguate the overloaded name 'cfg': self.cfg is the model config, dataset_cfg here is
# the train_ds config, and 'cfg' elsewhere refers to classifier-free guidance.
dataset = MagpieTTSLhotseDataset(
sample_rate=self.sample_rate,
volume_norm=dataset_cfg.volume_norm,
codec_model_samples_per_frame=self.codec_model_samples_per_frame,
audio_bos_id=self.audio_bos_id,
audio_eos_id=self.audio_eos_id,
context_audio_bos_id=self.context_audio_bos_id,
context_audio_eos_id=self.context_audio_eos_id,
num_audio_codebooks=self.data_num_audio_codebooks,
prior_scaling_factor=self.cfg.prior_scaling_factor,
load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
dataset_type=mode, # 'train' or 'test'; used to set phone prob to 1.0 for the test dataset (worker_init_fn)
load_16khz_audio=False,
pad_context_text_to_max_duration=self.pad_context_text_to_max_duration,
context_duration_min=self.cfg.context_duration_min,
context_duration_max=self.cfg.context_duration_max,
use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder,
text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name,
tokenizer_config=self.cfg.text_tokenizers,
text_context_remapping=self.text_context_remapping,
text_context_remapping_prob=self.text_context_remapping_prob,
)
data_loader = get_lhotse_dataloader_from_config(
config=dataset_cfg.dataset,
global_rank=self.global_rank,
world_size=self.world_size,
dataset=dataset,
)
return data_loader
def setup_training_data(self, dataset_cfg):
if dataset_cfg.get("use_lhotse", False):
# TODO @xueyang: disambiguate the overloaded name 'cfg': self.cfg is the model config, dataset_cfg here is
# the train_ds config, and 'cfg' elsewhere refers to classifier-free guidance.
# Set the target sample rate to the codec model's, since the lhotse config defaults to 16_000.
if not isinstance(dataset_cfg, DictConfig):
dataset_cfg = OmegaConf.create(dataset_cfg)
OmegaConf.set_struct(dataset_cfg.dataset, False)
dataset_cfg.dataset.update({"sample_rate": self.sample_rate})
OmegaConf.set_struct(dataset_cfg.dataset, True)
self._train_dl = self.get_lhotse_dataloader(dataset_cfg, mode='train')
else:
dataset = self.get_dataset(dataset_cfg, dataset_type='train')
sampler = dataset.get_sampler(dataset_cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
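# torch.utils.data.DataLoader only allows persistent_workers=True when num_workers > 0, so fall back to
# False for in-process data loading.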
persistent_workers = True
if dataset_cfg.dataloader_params.num_workers == 0:
persistent_workers = False
# For num_workers > 0, the tokenizer is assigned in worker_init_fn (it is not picklable)
dataset.text_tokenizer = setup_tokenizers(
all_tokenizers_config=self.cfg.text_tokenizers,
mode='train',
)
self._train_dl = torch.utils.data.DataLoader(
dataset,
collate_fn=dataset.collate_fn,
sampler=sampler,
**dataset_cfg.dataloader_params,
worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers,
)
def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader:
if dataset_cfg.get("use_lhotse", False):
# Set the target sample rate to the codec model's, since the lhotse config defaults to 16_000.
if not isinstance(dataset_cfg, DictConfig):
dataset_cfg = OmegaConf.create(dataset_cfg)
OmegaConf.set_struct(dataset_cfg.dataset, False)
dataset_cfg.dataset.update({"sample_rate": self.sample_rate})
OmegaConf.set_struct(dataset_cfg.dataset, True)
data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test')
else:
dataset = self.get_dataset(dataset_cfg, dataset_type='test')
persistent_workers = True
if dataset_cfg.dataloader_params.num_workers == 0:
persistent_workers = False
# For num_workers > 0, the tokenizer is assigned in worker_init_fn (it is not picklable)
dataset.text_tokenizer = setup_tokenizers(all_tokenizers_config=self.cfg.text_tokenizers, mode='test')
data_loader = torch.utils.data.DataLoader(
dataset,
collate_fn=dataset.collate_fn,
**dataset_cfg.dataloader_params,
worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers,
)
return data_loader
def setup_validation_data(self, dataset_cfg):
self._validation_dl = self._setup_test_dataloader(dataset_cfg)
def setup_test_data(self, dataset_cfg):
self._test_dl = self._setup_test_dataloader(dataset_cfg)
def setup_dummy_text_context_in_batch(
self,
batch: Dict[str, torch.Tensor],
) -> None:
"""Set up dummy text context tensors in the batch dictionary."""
# No text context provided - set up dummy if model requires text conditioning tensors
dummy_context_text = "[NO TEXT CONTEXT]"
dummy_tokens = self.tokenizer.encode(
text=dummy_context_text, tokenizer_name=self.text_conditioning_tokenizer_name
)
batch['context_text_tokens'] = torch.tensor([dummy_tokens], device=self.device, dtype=torch.long)
batch['context_text_tokens_lens'] = torch.tensor([len(dummy_tokens)], device=self.device, dtype=torch.long)
batch['has_text_context'] = torch.tensor([False], device=self.device, dtype=torch.bool)
def setup_dummy_audio_context_in_batch(
self,
batch: Dict[str, torch.Tensor],
context_audio: Optional[torch.Tensor] = None,
context_audio_lens: Optional[torch.Tensor] = None,
) -> None:
"""Set up dummy audio context tensors in the batch dictionary."""
# Model has baked context - create minimal dummy context tensors
# These will be ignored in prepare_context_tensors when baked embedding is used
dummy_context_codes = torch.zeros(
1, self.num_audio_codebooks, 2, device=self.device, dtype=torch.long
)
dummy_context_codes[:, :, 0] = self.context_audio_bos_id
dummy_context_codes[:, :, 1] = self.context_audio_eos_id
batch['context_audio_codes'] = dummy_context_codes
batch['context_audio_codes_lens'] = torch.tensor([2], device=self.device, dtype=torch.long)
def do_tts(
self,
transcript: str,
language: str = "en",
apply_TN: bool = False,
temperature: float = 0.7,
topk: int = 80,
max_decoder_steps: int = 500,
use_cfg: bool = True,
cfg_scale: float = 2.5,
) -> tuple:
"""
Generate speech from raw text transcript.
This is a convenience method for single-utterance text-to-speech synthesis.
For batch processing, use `infer_batch` directly. This method only supports context injection via a
baked context embedding; audio conditioning and text conditioning are not supported, so custom voice
generation is not available here.
Args:
transcript: Raw text to synthesize.
language: Language code for text normalization and tokenization.
Supported values depend on model's tokenizer configuration.
Common: "en" (English), "de" (German), "es" (Spanish), etc.
apply_TN: Whether to apply text normalization to the transcript.
If True, uses nemo_text_processing for normalization.
temperature: Sampling temperature for token generation.
topk: Top-k sampling parameter.
max_decoder_steps: Maximum number of decoder steps.
use_cfg: Whether to use classifier-free guidance.
cfg_scale: Scale factor for classifier-free guidance.
Returns:
Tuple of (audio, audio_len) where:
audio: Generated audio waveform. Shape: (1, T_audio).
audio_len: Length of generated audio in samples. Shape: (1,).
Raises:
ValueError: If model does not have a baked context embedding.
ImportError: If apply_TN=True but nemo_text_processing is not installed.
Example:
>>> # If text does not need to be normalized
>>> audio, audio_len = model.do_tts("Hello, how are you today?")
>>>
>>> # If text needs to be normalized
>>> audio, audio_len = model.do_tts(
... "Hello, how are you today?",
... apply_TN=True,
... )
"""
assert self.has_baked_context_embedding, "Model does not have a baked context embedding. Please use a checkpoint with a baked context embedding."
# Apply text normalization if requested
normalized_text = transcript
if apply_TN:
try:
from nemo_text_processing.text_normalization.normalize import Normalizer
normalizer = Normalizer(input_case='cased', lang=language)
normalized_text = normalizer.normalize(transcript, verbose=False)
logging.debug(f"Text normalization: '{transcript}' -> '{normalized_text}'")
except ImportError:
logging.warning(
"nemo_text_processing not installed. Skipping text normalization. "
"Install with: pip install nemo_text_processing"
)
# Determine tokenizer name based on language
# Try to find a matching tokenizer, fallback to first available
tokenizer_name = None
available_tokenizers = list(self.tokenizer.tokenizers.keys())
print(f"Available tokenizers: {available_tokenizers}")
# Common mappings for tokenizer names
language_tokenizer_map = {
"en": ["english_phoneme", "english"],
"de": ["german_phoneme", "german"],
"es": ["spanish_phoneme", "spanish"],
"fr": ["french_phoneme", "french"],
"it": ["italian_phoneme", "italian"],
"vi": ["vietnamese_phoneme", "vietnamese"],
"zh": ["mandarin_phoneme", "mandarin", "chinese"],
}
# Find matching tokenizer
if language in language_tokenizer_map:
for candidate in language_tokenizer_map[language]:
if candidate in available_tokenizers:
tokenizer_name = candidate
break
# Fallback to first available tokenizer
if tokenizer_name is None:
tokenizer_name = available_tokenizers[0]
logging.info(
f"No tokenizer found for language '{language}'. "
f"Using '{tokenizer_name}'. Available: {available_tokenizers}"
)
# Tokenize the transcript text
tokens = self.tokenizer.encode(text=normalized_text, tokenizer_name=tokenizer_name)
tokens = tokens + [self.eos_id] # Add EOS token (BOS not used per dataset convention)
text_tensor = torch.tensor([tokens], device=self.device, dtype=torch.long)
text_lens = torch.tensor([len(tokens)], device=self.device, dtype=torch.long)
# Create batch dictionary
batch = {
'text': text_tensor,
'text_lens': text_lens,
}
# Setup context in batch
if self.use_text_conditioning_encoder:
self.setup_dummy_text_context_in_batch(batch)
self.setup_dummy_audio_context_in_batch(batch)
# Run inference
with torch.no_grad():
output = self.infer_batch(
batch,
max_decoder_steps=max_decoder_steps,
temperature=temperature,
topk=topk,
use_cfg=use_cfg,
cfg_scale=cfg_scale,
)
return output.predicted_audio, output.predicted_audio_lens
@classmethod
def list_available_models(cls) -> List[PretrainedModelInfo]:
return []
|