# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random
import time
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, List, Optional, Union

import numpy as np
import soundfile as sf
import torch
import wandb
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict
from torch import nn
from torch.utils.data import get_worker_info

from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers
from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss
from nemo.collections.tts.models import AudioCodecModel
from nemo.collections.tts.modules import transformer_2501
from nemo.collections.tts.modules.aligner import AlignmentEncoder
from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter
from nemo.collections.tts.modules.magpietts_modules import (
    CharAwareSubwordEncoder,
    EOSDetectionMethod,
    LocalTransformerType,
    SpecialAudioToken,
    cosine_schedule,
)
from nemo.collections.tts.parts.utils.helpers import (
    binarize_attention_parallel,
    get_mask_from_lengths,
    plot_alignment_to_numpy,
)
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging


@dataclass
class InferBatchOutput:
    """Output dataclass for MagpieTTS infer_batch method.

    This provides a consistent return type regardless of which optional outputs
    are requested.

    Attributes:
        predicted_audio: Generated audio waveforms. Shape: (B, T_audio).
        predicted_audio_lens: Length of each audio in samples. Shape: (B,).
        predicted_codes: Generated audio codec tokens. Shape: (B, num_codebooks, T_frames).
        predicted_codes_lens: Length of each code sequence in frames. Shape: (B,).
        rtf_metrics: Dictionary containing real-time factor and timing metrics.
        cross_attention_maps: Optional cross-attention visualization maps.
            List of numpy arrays, one per batch item. Only populated if
            return_cross_attn_probs=True.
        headwise_cross_attention_maps: Optional per-head cross-attention maps.
            Only populated if return_cross_attn_probs=True and
            compute_all_heads_attn_maps=True.
    """

    predicted_audio: torch.Tensor
    predicted_audio_lens: torch.Tensor
    predicted_codes: torch.Tensor
    predicted_codes_lens: torch.Tensor
    rtf_metrics: Dict[str, Any]
    cross_attention_maps: Optional[List[Any]] = None
    headwise_cross_attention_maps: Optional[List[Any]] = None
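
# Illustrative sketch (not executed): one way a caller might consume InferBatchOutput.
# `model`, `batch`, and the output filenames are hypothetical, and the infer_batch call
# signature is assumed; `sf` is the soundfile module imported above, and
# `model.sample_rate` is set in MagpieTTSModel.__init__.
#
#   out = model.infer_batch(batch)  # returns an InferBatchOutput
#   for i in range(out.predicted_audio.size(0)):
#       n_samples = int(out.predicted_audio_lens[i])
#       audio_np = out.predicted_audio[i, :n_samples].float().cpu().numpy()
#       sf.write(f"sample_{i}.wav", audio_np, samplerate=model.sample_rate)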


def worker_init_fn(worker_id):
    # For mp.set_start_method("spawn", force=True)
    # The dataset class should be picklable, so we initialize non-picklable objects here
    logging.info(f"Worker {worker_id} initializing...")
    worker_info = get_worker_info()
    dataset = worker_info.dataset  # Get the dataset instance in this worker
    tokenizer = setup_tokenizers(dataset.tokenizer_config, mode=dataset.dataset_type)
    dataset.text_tokenizer = tokenizer
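
# Illustrative sketch (not executed): `worker_init_fn` above is intended to be passed to a
# PyTorch DataLoader so that each spawned worker rebuilds its non-picklable text tokenizer.
# The `dataset` object and worker count below are hypothetical; in this model the dataloader
# is built via get_lhotse_dataloader_from_config.
#
#   loader = torch.utils.data.DataLoader(
#       dataset,
#       batch_size=None,  # batching handled by the (lhotse) sampler
#       num_workers=4,
#       worker_init_fn=worker_init_fn,
#   )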


class MagpieTTSModel(ModelPT):
    """
    Magpie-TTS model base class, used for training a TTS model that generates audio codes from a transcript and
    context audio/text.

    Supports multiple model types:

    - multi_encoder_context_tts: Transcript and context audio go to different encoders. The transcript encoding feeds
    into the decoder layers given by cfg.model.transcript_decoder_layers and the context encoding feeds into the
    layers given by context_decoder_layers. Also supports text context, which gets encoded by the same encoder as
    context audio. Only one of context audio or context text is supported.

    - decoder_context_tts: Text goes into the encoder; context & target audio go to the decoder. Also supports text
    context. Supports a fixed-size context, so we set context_duration_min and context_duration_max to the same
    value (5 seconds). Text context, which is usually shorter than the number of codec frames in 5 seconds of audio,
    is padded to the max context duration in this model.

    - decoder_ce: Same as decoder_context_tts except there is a small neural network between the context tensors and
    the decoder input.
    """

    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_devices

        # load codec, disable loading of loss modules not needed during inference
        codec_model_path = cfg.get('codecmodel_path')
        if codec_model_path.startswith('nvidia/'):
            codec_model = AudioCodecModel.from_pretrained(codec_model_path)
        else:
            codec_model_cfg = AudioCodecModel.restore_from(codec_model_path, return_config=True)
            if "use_scl_loss" in codec_model_cfg:
                codec_model_cfg.use_scl_loss = False
            codec_model = AudioCodecModel.restore_from(
                codec_model_path, strict=False, override_config_path=codec_model_cfg
            )
        self.sample_rate = codec_model.sample_rate
        self.codec_model_samples_per_frame = codec_model.samples_per_frame
        # del codec discriminator to free memory
        del codec_model.discriminator

        # When using FSQ tokens, the codebook structure can be changed at any time.
        # An FSQ definition can be provided in `vector_quantizer` config to train with a codebook structure
        # that is different than in the audio codec checkpoint.
        vector_quantizer = cfg.get('vector_quantizer')
        if vector_quantizer is not None:
            vector_quantizer = instantiate(vector_quantizer)
            num_audio_codebooks = vector_quantizer.num_codebooks
            codebook_size = vector_quantizer.codebook_size
            codec_converter = VectorQuantizerIndexConverter(
                vector_quantizer_original=codec_model.vector_quantizer,
                vector_quantizer_new=vector_quantizer,
            )
            data_num_audio_codebooks = codec_model.vector_quantizer.num_codebooks
        else:
            num_audio_codebooks = codec_model.num_codebooks
            data_num_audio_codebooks = num_audio_codebooks
            codebook_size = codec_model.codebook_size
            codec_converter = None
        # The dataloader needs to know the number of codebooks that the context codes were stored in.
        # If there are no saved context codes and no context audio (the text-context path),
        # we create a dummy context code tensor containing only [context_BOS, context_EOS], repeated for
        # data_num_audio_codebooks codebooks.
        self.data_num_audio_codebooks = data_num_audio_codebooks
        self.num_audio_codebooks = num_audio_codebooks
        self.codebook_size = codebook_size

        # Our codebooks start with actual audio codec tokens, followed by special tokens.
        # The `forced_*` options are for backward compatibility for models trained with older code.
        get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size)
        self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS))
        self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS))
        self.context_audio_bos_id = cfg.get(
            'forced_context_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS)
        )
        self.context_audio_eos_id = cfg.get(
            'forced_context_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS)
        )
        self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN))
        self.num_all_tokens_per_codebook = cfg.get(
            'forced_num_all_tokens_per_codebook', self.codebook_size + len(SpecialAudioToken)
        )
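        # Token-id layout sketch (illustrative; the exact special-token order is defined by
        # the SpecialAudioToken enum):
        #   [0, codebook_size)                           -> real codec tokens
        #   [codebook_size, num_all_tokens_per_codebook) -> special tokens (AUDIO_BOS, AUDIO_EOS,
        #                                                   AUDIO_CONTEXT_BOS, AUDIO_CONTEXT_EOS,
        #                                                   MASK_TOKEN, ...)
        # E.g. with a hypothetical codebook_size of 2016, the first special token would sit at
        # index 2016 and each codebook's embedding table has 2016 + len(SpecialAudioToken) rows.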
        self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False)

        # The frame stacking factor controls how many consecutive frames are processed together by the base decoder
        # (and then refined into individual frames by the local transformer). A frame stacking factor of 1 means no
        # frame stacking. We have a separate embedding table for each of the stacked frames, e.g. for frame stacking
        # factor of 3, the entries of codebook 0 appear 3 times in the embedding table.
        self.frame_stacking_factor = cfg.get('frame_stacking_factor', 1)
        assert 'downsample_factor' not in cfg, '`downsample_factor` is deprecated, use `frame_stacking_factor` instead'
        # Setup tokenizer
        if hasattr(cfg, 'text_tokenizer'):
            # For backward compatibility for English-only models
            with open_dict(cfg):
                cfg.text_tokenizers = {"english_phoneme": cfg.text_tokenizer}
                del cfg['text_tokenizer']

        self.use_text_conditioning_encoder = cfg.get('use_text_conditioning_encoder', False)
        # Using google-t5/t5-small as default text conditioning tokenizer for backward compatibility.
        self.text_conditioning_tokenizer_name = cfg.get('text_conditioning_tokenizer_name', None)
        self.legacy_text_conditioning = cfg.get('legacy_text_conditioning', False)

        if self.legacy_text_conditioning:
            if self.text_conditioning_tokenizer_name is None:
                self.text_conditioning_tokenizer_name = "google-t5/t5-small"

            tokenizer_target = "AutoTokenizer"
            if self.text_conditioning_tokenizer_name == "google-t5/t5-small":
                tokenizer_target = "T5Tokenizer"

            with open_dict(cfg):
                cfg.text_tokenizers[self.text_conditioning_tokenizer_name] = {
                    '_target_': tokenizer_target,
                    'pretrained_model': self.text_conditioning_tokenizer_name,
                }
        elif self.text_conditioning_tokenizer_name is None:
            # If no text_conditioning_tokenizer_name is specified, use the first one as default
            # For text context tokenization
            self.text_conditioning_tokenizer_name = list(cfg.text_tokenizers.keys())[0]

        # TODO @xueyang: both tokenizers are only used to get some token ids. We
        # should kill them to save a small amount of mem resources since dataloader will initialize them
        # again after the worker processes are spawned.
        self.tokenizer = setup_tokenizers(
            all_tokenizers_config=cfg.text_tokenizers,
            mode='train',
        )

        num_tokens_tokenizer = len(self.tokenizer.tokens)
        if self.legacy_text_conditioning:
            # Text context tokens are not a part of the regular transcript embedding table in legacy models
            num_tokens_tokenizer -= self.tokenizer.num_tokens_per_tokenizer[self.text_conditioning_tokenizer_name]

        num_tokens = num_tokens_tokenizer + 2  # +2 for BOS and EOS
        self.bos_id = num_tokens - 2
        self.eos_id = num_tokens - 1
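        # Worked example with illustrative numbers: if the combined tokenizers expose 256 text
        # tokens, then num_tokens = 258, bos_id = 256 and eos_id = 257, i.e. BOS/EOS occupy the
        # last two rows of the text embedding table created further below.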

        self.model_type = cfg.get('model_type', None)
        self.pad_context_text_to_max_duration = self.model_type in ['decoder_context_tts', 'decoder_ce']
        self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False)

        # The args below (text_context_remapping_json, text_context_remapping_prob) are
        # for combining multiple context_texts into a single one during training.
        # E.g. if we want to treat Emma_neutral and Emma_conversational as one speaker,
        # we can create an override dict {'Emma_neutral': 'Emma', 'Emma_conversational': 'Emma'}.
        # This dict is saved in a json file given by cfg.model.text_context_remapping_json.
        # If we want to preserve both behaviours, i.e. (Emma_neutral, Emma_conversational) as well as just (Emma),
        # we can apply this mapping with a probability during training, as specified by text_context_remapping_prob.
        self.text_context_remapping = None
        text_context_remapping_json = cfg.get('text_context_remapping_json', None)
        self.text_context_remapping_prob = cfg.get('text_context_remapping_prob', 0.0)
        if text_context_remapping_json is not None:
            with open(text_context_remapping_json, 'r') as f:
                self.text_context_remapping = json.load(f)

        super().__init__(cfg=cfg, trainer=trainer)

        if self.legacy_text_conditioning:
            tc_tokenizer = self.tokenizer.tokenizers[self.text_conditioning_tokenizer_name]
            self.context_text_embedding = nn.Embedding(tc_tokenizer.vocab_size, cfg.embedding_dim)

        # This needs to happen after super().__init__()
        self._codec_model = codec_model
        self._codec_model.freeze()  # Lightning does requires_grad = False and self.eval()
        self._codec_converter = codec_converter

        audio_embeddings = []
        for _ in range(self.num_audio_codebooks * self.frame_stacking_factor):
            audio_embeddings.append(nn.Embedding(self.num_all_tokens_per_codebook, cfg.embedding_dim))
        self.audio_embeddings = nn.ModuleList(audio_embeddings)
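        # Embedding-table indexing sketch: with frame stacking, codebook `c` at stacked frame
        # position `i` uses table index `c + i * num_audio_codebooks` (see embed_audio_tokens and
        # compute_local_transformer_logits). For example, with num_audio_codebooks=8 and
        # frame_stacking_factor=3 there are 24 tables, and codebook 0 of the second stacked
        # frame (i=1) uses table index 8.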

        if self.use_bpe_char_tokenizer:
            # BPE char tokenizer
            assert len(self.tokenizer.tokenizers) == 1, "BPE char tokenizer should only be used with one tokenizer"
            tokenizer_name = self.tokenizer.tokenizer_names[0]
            tokenizer = self.tokenizer.tokenizers[tokenizer_name]
            subword_vocab = tokenizer.get_vocab()
            # Special tokens are stored as-is in the char_vocab;
            # each special token is mapped to exactly one char id.
            special_vocab = {
                '<BOS>': self.bos_id,
                '<EOS>': self.eos_id,
            }
            self.cas_encoder = CharAwareSubwordEncoder(
                d_embed=cfg.embedding_dim,
                llm_tokenizer_vocab=subword_vocab,
                subword_padding_idx=self.tokenizer.pad,
                special_vocab=special_vocab,
            )
        else:
            # Regular text embedding
            self.text_embedding = nn.Embedding(num_tokens, cfg.embedding_dim)

        self.encoder = transformer_2501.Transformer(**dict(cfg.encoder))
        self.decoder = transformer_2501.Transformer(**dict(cfg.decoder))
        self.final_proj = nn.Linear(
            cfg.decoder.d_model,
            self.num_audio_codebooks * self.num_all_tokens_per_codebook * self.frame_stacking_factor,
        )
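        # final_proj output layout (illustrative numbers): per decoder timestep, the logits of
        # every (stacked frame, codebook) pair are concatenated along the feature dimension, so
        # with a hypothetical num_audio_codebooks=8, num_all_tokens_per_codebook=2021 and
        # frame_stacking_factor=1 the projection width is 8 * 2021 = 16168.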

        self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower())
        logging.info(f"Local transformer type: {self.local_transformer_type}")
        if self.local_transformer_type != LocalTransformerType.NO_LT:
            local_transformer_hidden_dim = cfg.get('local_transformer_hidden_dim', 256)
            if local_transformer_hidden_dim != cfg.decoder.d_model:
                self.local_transformer_in_projection = nn.Linear(cfg.decoder.d_model, local_transformer_hidden_dim)
            else:
                self.local_transformer_in_projection = nn.Identity()
            self.local_transformer = transformer_2501.Transformer(
                n_layers=self.cfg.get('local_transformer_n_layers', 2),
                d_model=local_transformer_hidden_dim,
                d_ffn=local_transformer_hidden_dim * 4,
                sa_n_heads=self.cfg.get('local_transformer_n_heads', 1),
                kernel_size=1,
                is_causal=self.local_transformer_type == LocalTransformerType.AR,
                max_length_causal_mask=self.frame_stacking_factor * self.num_audio_codebooks + 2,
                use_learnable_pos_emb=True,
            )
            local_transformer_out_projections = []
            for _ in range(self.num_audio_codebooks * self.frame_stacking_factor):
                # Have a separate projection layer for each codebook, to distinguish between them
                local_transformer_out_projections.append(
                    nn.Linear(local_transformer_hidden_dim, self.num_all_tokens_per_codebook)
                )
            self.local_transformer_out_projections = nn.ModuleList(local_transformer_out_projections)

        if cfg.get('use_alignment_encoder', False):
            self.alignment_encoder = AlignmentEncoder(
                n_mel_channels=cfg.embedding_dim,
                n_text_channels=cfg.embedding_dim,
                dist_type="cosine",
                temperature=15.0,
            )

        if self.model_type == 'multi_encoder_context_tts':
            logging.warning(f"The multi_encoder_context_tts model type for {self} is deprecated.")

            # Transcript and context audio/text go to different encoders.
            # Output of the encoders goes to the decoder through the cross-attention layers
            self.transcript_decoder_layers = cfg.get('transcript_decoder_layers', [3, 4, 5, 6, 7, 8])
            self.context_decoder_layers = cfg.get(
                'context_decoder_layers', [0, 1, 2, 9, 10, 11]
            )  # For backward compatibility
            multi_encoder_mapping = [None for _ in range(self.decoder.n_layers)]
            for layer in self.transcript_decoder_layers:
                multi_encoder_mapping[layer] = 0  # 0 means text goes to this layer, 1 means context goes to this layer
            for layer in self.context_decoder_layers:
                multi_encoder_mapping[layer] = 1
            self.multi_encoder_mapping = multi_encoder_mapping
            self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder))
        elif self.model_type == 'decoder_context_tts':
            # Context audio/text goes directly to the decoder (before the target audio codes)
            self.transcript_decoder_layers = [
                idx for idx in range(self.decoder.n_layers)
            ]  # All layers are used for text
        elif self.model_type == 'decoder_ce':
            # Similar to decoder_context_tts, but we use context encoder
            # Decoder gets output from context encoder instead of raw context tokens embeddings
            self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder))
            self.transcript_decoder_layers = [
                idx for idx in range(cfg.decoder.n_layers)
            ]  # All layers are used for text
            # Register buffers for baked context embedding (initially None/empty)
            # These will be populated when loading a checkpoint with baked embedding
            self.register_buffer('baked_context_embedding', None)
            self.register_buffer('baked_context_embedding_len', None)
        else:
            raise ValueError(f"Unsupported model type {self.model_type}")

        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
        self.alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0)
        self.alignment_encoder_loss_scale = cfg.get('alignment_encoder_loss_scale', 0.0)
        if self.alignment_loss_scale > 0.0:
            self.alignment_loss = ForwardSumLoss(loss_scale=self.alignment_loss_scale)
        if self.alignment_encoder_loss_scale > 0.0:
            self.alignment_encoder_loss = ForwardSumLoss(loss_scale=self.alignment_encoder_loss_scale)

        # Define cfg parameters into self parameters
        self.prior_end_step = self.cfg.prior_end_step
        self.prior_scaledown_start_step = self.cfg.prior_scaledown_start_step
        self.indefinite_prior_prob = self.cfg.get('indefinite_prior_prob', 0.0)
        self.ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers)
        self.cfg_unconditional_prob = self.cfg.get('cfg_unconditional_prob', 0.0)
        self.decoder_input_dropout_prob = self.cfg.get('decoder_input_dropout_prob', 0.0)
        self.binarize_attn_method = self.cfg.get('binarize_attn_method', 'argmax')
        self.binarize_repeat_audio_factor = self.cfg.get('binarize_repeat_audio_factor', 2)
        self.prior_future_decay = self.cfg.get('prior_future_decay', 1.0)
        self.prior_past_decay = self.cfg.get('prior_past_decay', 1.0)
        self.binarized_prior_epsilon = self.cfg.get('binarized_prior_epsilon', 0.0)
        self.prior_future_context = self.cfg.get('prior_future_context', 1)
        self.prior_past_context = self.cfg.get('prior_past_context', 1)
        self.binarize_prior_after_step = self.cfg.get('binarize_prior_after_step', 0)
        self.codebook_loss_scale = self.cfg.get('codebook_loss_scale', 1.0)
        self.local_transformer_loss_scale = self.cfg.get('local_transformer_loss_scale', 1.0)
        self.use_alignment_encoder = self.cfg.get('use_alignment_encoder', False)
        self.use_prior_for_aligner = self.cfg.get('use_prior_for_aligner', False)
        self.aligner_encoder_train_steps = self.cfg.get('aligner_encoder_train_steps', float('inf'))
        self.dec_random_input_max = self.cfg.get('dec_random_input_max', self.num_all_tokens_per_codebook)

        # Configuration validity checks
        self.check_frame_stacking_config_validity()

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """
        Only used for saving checkpoints. On save, we remove _speaker_verification_model and _codec_model
        from the checkpoint. The codec model is saved in a separate checkpoint.

        _speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts
        model_type that is no longer supported and can likely be removed in a future version.

        If the model has a baked context embedding, the context_encoder weights are also excluded
        since they are no longer needed for inference.
        """
        if hasattr(self, '_no_state_dict') and self._no_state_dict:
            return {}
        # Don't save the speaker verification and codec model in the state dict
        state_dict = super().state_dict(destination, prefix, keep_vars)
        keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model']
        # If we have a baked context embedding, exclude context_encoder weights
        if self.has_baked_context_embedding:
            keys_substrings_to_exclude.append('context_encoder')
        for key in list(state_dict.keys()):
            if any([substring in key for substring in keys_substrings_to_exclude]):
                del state_dict[key]
        return state_dict

    def check_frame_stacking_config_validity(self):
        """
        Check if the configuration is compatible with frame stacking.
        """
        if self.frame_stacking_factor > 1:
            # The settings below are not supported with frame stacking.
            # Some of them may work - but they have not been tested.

            # disallow alignment encoder
            if self.use_alignment_encoder:
                raise ValueError("Alignment encoder is not supported for frame stacking")
            # disallow alignment loss
            if self.alignment_loss_scale > 0.0:
                raise ValueError("Alignment loss is not supported for frame stacking")
            # disallow training prior
            if self.cfg.prior_scaling_factor is not None and self.cfg.prior_scaling_factor > 0:
                raise ValueError("Training-time attention prior is not supported for frame stacking")
            # disallow text conditioning
            if self.use_text_conditioning_encoder:
                raise ValueError("Text conditioning is not supported for frame stacking")

    @property
    def has_baked_context_embedding(self) -> bool:
        """Check if the model has a baked context embedding.

        Returns:
            True if baked_context_embedding buffer is set, not None, and has elements.
        """
        return (
            self.model_type == 'decoder_ce'
            and hasattr(self, 'baked_context_embedding')
            and self.baked_context_embedding is not None
            and self.baked_context_embedding.numel() > 0
        )

    def update_ckpt(self, state_dict):
        """
        Backward compatibility for checkpoints saved with old model names.
        """
        new_state_dict = {}
        for key in state_dict.keys():
            if 't5_encoder' in key:
                new_key = key.replace('t5_encoder', 'encoder')
                new_state_dict[new_key] = state_dict[key]
            elif 't5_decoder' in key:
                new_key = key.replace('t5_decoder', 'decoder')
                new_state_dict[new_key] = state_dict[key]
            else:
                new_state_dict[key] = state_dict[key]
        return new_state_dict

    def load_state_dict(self, state_dict, strict=True):
        """
        Modify load_state_dict so that we don't restore weights to _speaker_verification_model and _codec_model when
        strict is True.
        When strict is False, we can call pytorch's load_state_dict.
        When strict is True, we loop through all parameters and rename them to enable loading.

        _speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts
        model_type that is no longer supported and can likely be removed in a future version.

        Also handles loading baked context embeddings. If the checkpoint contains baked_context_embedding,
        context_encoder weights are not expected to be present.
        """
        state_dict = self.update_ckpt(state_dict)

        # Check if checkpoint has baked context embedding
        has_baked_embedding_in_ckpt = (
            'baked_context_embedding' in state_dict and state_dict['baked_context_embedding'] is not None
        )

        # Load baked embedding buffers if present
        if has_baked_embedding_in_ckpt:
            self.baked_context_embedding = state_dict['baked_context_embedding']
            self.baked_context_embedding_len = state_dict['baked_context_embedding_len']
            logging.info(
                f"Loaded baked context embedding with shape {self.baked_context_embedding.shape}, "
                f"length {self.baked_context_embedding_len.item()}"
            )

        if not strict:
            super().load_state_dict(state_dict, strict=False)

        # Build list of modules to skip
        modules_to_skip = [
            '_speaker_verification_model',
            '_codec_model',
            '_reference_model',
            'eval_asr_model',
            'eval_speaker_verification_model',
            'whisper_model',
            'squim_objective_model',
        ]
        # Skip context_encoder if checkpoint has baked embedding (weights won't be in checkpoint)
        if has_baked_embedding_in_ckpt:
            modules_to_skip.append('context_encoder')

        for name, child in self.named_children():
            if name in modules_to_skip:
                continue
            if any(param.numel() > 0 for param in child.parameters()):
                # If the module has parameters, we want to change the default mapping so that the state_dict gets
                # loaded.
                # Ex: state_dict[encoder.position_embeddings.weight] -> new_state_dict[position_embeddings.weight]
                new_state_dict = {}
                for key in state_dict.keys():
                    name_with_dot = f"{name}."
                    if key.startswith(name_with_dot):
                        new_state_dict[key[len(name_with_dot) :]] = state_dict[key]
                child.load_state_dict(new_state_dict)

    def audio_to_codes(self, audio, audio_len, audio_type='target'):
        # audio: (B, T)
        # audio_len: (B,)
        if audio_type == 'target':
            audio_eos_id = self.audio_eos_id
            audio_bos_id = self.audio_bos_id
        elif audio_type == 'context':
            audio_eos_id = self.context_audio_eos_id
            audio_bos_id = self.context_audio_bos_id
        else:
            raise ValueError(f"Received audio_type of {audio_type}. Must be `target` or `context`")

        self._codec_model.eval()
        with torch.no_grad(), torch.autocast(device_type=audio.device.type, dtype=torch.float32):
            codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len)
            if self._codec_converter is not None:
                codes = self._codec_converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len)
            # Add a timestep to the beginning and end of the codes tensor
            bos_tensor = torch.full(
                (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device
            )
            # pad at the end to make room for the EOS token; the EOS token's actual position
            # varies per batch element depending on each element's length.
            pad_tensor = torch.full(
                (codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device
            )  # 0 is the padding token in the audio codebook
            codes = torch.cat([bos_tensor, codes, pad_tensor], dim=-1)
            # codes: (B, C, T')
            # codes_len: (B,)
            for idx in range(codes.size(0)):
                codes[idx, :, codes_len[idx] + 1] = audio_eos_id
            codes_len = codes_len + 2  # +1 for bos and +1 for eos
            return codes.long(), codes_len.long()
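    # audio_to_codes() output layout (sketch): for a batch element whose codec encoding has L
    # valid frames, each codebook row becomes
    #     [audio_bos_id, c_1, ..., c_L, audio_eos_id, <padding>...]
    # with the returned codes_len = L + 2; positions after the EOS are padding and should be
    # ignored via the length.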

    def codes_to_audio(self, codes, codes_len):
        # codes: (B, C, T')
        # codes_len: (B,)
        self._codec_model.eval()
        with torch.no_grad(), torch.autocast(device_type=codes.device.type, dtype=torch.float32):
            # Make a copy to avoid modifying the original tensor if it's used elsewhere
            codes_copy = codes.clone()
            # Replace eos and bos tokens with padding in the copied tensor
            codes_copy[codes == self.audio_bos_id] = 0  # zero is the padding token
            codes_copy[codes == self.audio_eos_id] = 0
            # Pass the modified integer token IDs
            if self._codec_converter is not None:
                codes_copy = self._codec_converter.convert_new_to_original(
                    audio_tokens=codes_copy, audio_lens=codes_len
                )
            audio, audio_len = self._codec_model.decode(tokens=codes_copy, tokens_len=codes_len)
            # audio: (B, T)
            # audio_len: (B,)
            return audio, audio_len

    def embed_audio_tokens(self, audio_tokens):
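        """
        Embed (possibly frame-stacked) audio codes into a single vector sequence.

        audio_tokens: (B, C, T) integer codes; T is assumed to be a multiple of frame_stacking_factor.
        Returns: (B, T // frame_stacking_factor, E) embeddings, obtained by averaging the per-codebook,
        per-stack-position embeddings (each (codebook, stack position) pair has its own table).
        """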
        B, C, T = audio_tokens.shape
        audio_embedding = None
        for i in range(self.frame_stacking_factor):
            for c in range(C):
                tokens = audio_tokens[:, c, i :: self.frame_stacking_factor]
                embedding = self.audio_embeddings[c + i * C](tokens)
                if audio_embedding is None:
                    audio_embedding = embedding
                else:
                    audio_embedding += embedding
        audio_embedding = audio_embedding / (C * self.frame_stacking_factor)
        return audio_embedding

    def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_offset_by_one=False):
        """
        Predicts the logits for all codebooks using the local transformer. Used in both autoregressive (AR) and MaskGit (MG) modes.
        This function is used in training and validation, not inference/sampling.
        The sequence layout is slightly different between AR and MG modes, as shown in the diagram below,
        (using an 8-codebook setup as an example):
        +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
        | AR target  |    0    |    1    |    2    |    3    |    4    |    5    |    6    |    7    |   none  |
        | codebook   |         |         |         |         |         |         |         |         |         |
        +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
        | MG target  |  none   |    0    |    1    |    2    |    3    |    4    |    5    |    6    |    7    |
        | codebook   |         |         |         |         |         |         |         |         |         |
        +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
        |  input     | Magpie  |    0    |    1    |    2    |    3    |    4    |    5    |    6    |    7    |
        |  codebook  | latent  | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK |
        +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+
        | seq. index |    0    |    1    |    2    |    3    |    4    |    5    |    6    |    7    |    8    |
        +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+

        dec_out: (B, T', E)
        audio_codes_target: (B, C, T')
        targets_offset_by_one: bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive)
                                     if True,  the target for index 1 is codebook 0, for index 2 is codebook 1, etc. (MaskGit)
        """
        C = self.num_audio_codebooks
        dec_out_all = dec_out.reshape(-1, dec_out.size(-1))  # (B*T', E)
        local_transformer_input = [dec_out_all]
        # Build the teacher-forced input to the LT.
        for fs_index in range(self.frame_stacking_factor):
            for codebook_num in range(C):
                # Collect ground-truth codes for the current codebook and frame-stack index combination.
                codes = audio_codes_target[:, codebook_num, fs_index :: self.frame_stacking_factor]  # (B, T')
                # Individual timesteps are handled independently by the LT, so fold time into the batch dimension.
                codes = codes.reshape(-1)  # (B*T',)
                # Embed the codes
                codebook_embedding = self.audio_embeddings[codebook_num + fs_index * C](codes)  # (B*T', E)
                local_transformer_input.append(codebook_embedding)
        # Stack the input codes along dimension 1 (codebooks). This is the dimension along which the LT predicts iteratively.
        local_transformer_input = torch.stack(local_transformer_input, dim=1)  # (B*T', C+1, E)
        local_transformer_input = self.local_transformer_in_projection(local_transformer_input)  # (B*T', C+1, local_transformer_hidden_dim)
        _mask = torch.ones(
            local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device
        )
        local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output']  # (B*T', C+1, E)
        if not targets_offset_by_one:
            # for autoregressive local transformer the target for index 0 is codebook 0, for index 1 is codebook 1, etc.
            local_transformer_output = local_transformer_output[:, :-1, :]  # (B*T', C, E)
        else:
            # for MaskGit the target for index **1** is codebook 0, for index 2 is codebook 1, etc.
            local_transformer_output = local_transformer_output[:, 1:, :]  # (B*T', C, E)
        all_code_logits = []
        for fs_index in range(self.frame_stacking_factor):
            for codebook_num in range(audio_codes_target.size(1)):
                # Using a separate projection layer for each codebook (to distinguish between them)
                # Timing was checked: this loop is cheap compared to the local transformer forward pass.
                codebook_logits = self.local_transformer_out_projections[codebook_num + fs_index * C](
                    local_transformer_output[:, codebook_num + fs_index * C, :]
                )  # (B*T', num_all_tokens_per_codebook)
                all_code_logits.append(codebook_logits)
        all_code_logits = torch.cat(
            all_code_logits, dim=1
        )  # (B*T'/frame_stacking_factor, num_codebooks * num_all_tokens_per_codebook * frame_stacking_factor)

        all_code_logits = all_code_logits.view(
            audio_codes_target.size(0), audio_codes_target.size(2) // self.frame_stacking_factor, -1
        )  # (B, T'/frame_stacking_factor, C * num_all_tokens_per_codebook * frame_stacking_factor)

        return all_code_logits

    def maskgit_create_random_mask(self, codes):
        """
        Creates a mask where True indicates the positions that should be replaced with a MASK_TOKEN.
        """
        # Codes: (B, C, T)
        B, C, T = codes.shape
        # get a random vector sampled uniformly from [0, 1)  # TODO: does it need to be inclusive on the right?
        rand_values = torch.rand(B, T, device=codes.device)
        # apply the cosine schedule
        frac_masked = cosine_schedule(rand_values)
        # how many positions to mask
        n_masked = torch.ceil(frac_masked * C).long()  # B,T
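        # For illustration (assuming cosine_schedule(x) = cos(pi * x / 2), which is not defined in this
        # snippet): a position that draws rand = 0.5 gets frac_masked = cos(pi / 4) ~= 0.71, so with
        # C = 8 codebooks, n_masked = ceil(0.71 * 8) = 6 of its codebooks are masked.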
        # The code further below is the vectorized version of this:
        #  for b in range(B):
        #      for t in range(T):
        #          if n_masked[b,t] > 0:
        #              # get a random permutation of the codebook indices
        #              perm = torch.randperm(C)
        #              # mask the top n_masked positions
        #              mask[b, perm[:n_masked[b,t]], t] = True
        #
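        # For illustration only (toy numbers, a single (b, t) position with C = 4 and n_masked = 2):
        #   rand = torch.tensor([0.81, 0.12, 0.55, 0.30])
        #   perm = torch.argsort(rand)            # tensor([1, 3, 2, 0]) -- a random permutation of codebooks
        #   base = torch.arange(4) < 2            # tensor([True, True, False, False])
        #   mask = base[perm]                     # tensor([True, False, False, True]) -- 2 random codebooks masked
        # The gather below performs the same re-indexing, batched over (B, C, T).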
        # Create random permutations
        random_permutations = torch.argsort(torch.rand(B, C, T, device=codes.device), dim=1)  # (B, C, T)
        # Create a mask tensor where each position indicates if it should be masked
        mask_indices = torch.arange(C, device=codes.device).view(1, C, 1)
        mask = mask_indices < n_masked.view(B, 1, T)  # (B, C, T)
        # Apply the random permutations to the mask
        mask = torch.gather(mask, 1, random_permutations)

        return mask  # (B, C, T)

    def maskgit_apply_random_mask(self, codes):
        # Randomly replaces some codes with the MASK_TOKEN with a proportion following the cosine schedule.
        # Codes: (B, C, T)
        mask = self.maskgit_create_random_mask(codes)
        # replace some tokens with MASK_TOKEN
        codes_with_mask = torch.where(mask, self.mask_token_id, codes)
        return codes_with_mask, mask

    def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=None, frame_stacking_factor=1):
        """
        Computes the audio codebook loss. Used by
        (1) The main Magpie-TTS transformer
        (2) The local transformer, for both autoregressive and MaskGit methods

        logits: (B, T', num_codebooks * num_tokens_per_codebook)
        audio_codes: (B, C, T')
        audio_codes_lens: (B,)
        mask_tokens_mask: (B, C, T') True for tokens that were replaced with the MASK_TOKEN and should
                                     therefore be the only ones included in the loss computation (for MaskGit).
        frame_stacking_factor: int, the stacking factor used in the model
        """
        loss_mask = get_mask_from_lengths(audio_codes_lens, pad_to_factor=frame_stacking_factor)
        if mask_tokens_mask is not None:
            # For MaskGit we only compute loss for the masked tokens.
            # *Both* conditions must be true:
            # 1. the token is masked
            # 2. the token is not padding
            loss_mask = loss_mask.unsqueeze(1) * mask_tokens_mask
            if not loss_mask.any():
                # Without this check we very rarely got NaNs in the loss
                logging.warning("No valid tokens were found in compute_loss()!")
                return torch.tensor(0.0, device=loss_mask.device), loss_mask
        else:
            # repeat loss mask for each codebook to simplify code below
            loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1)
        total_codebook_loss = None
        for fs_index in range(frame_stacking_factor):
            for codebook in range(audio_codes.size(1)):
                si = (codebook + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook
                ei = si + self.num_all_tokens_per_codebook
                codebook_logits = logits[:, :, si:ei]  # (B, T', num_tokens_per_codebook)
                codebook_targets = audio_codes[:, codebook, fs_index::frame_stacking_factor]  # (B, T')
                codebook_loss = self.cross_entropy_loss(
                    codebook_logits.permute(0, 2, 1), codebook_targets  # (B, num_tokens_per_codebook, T')
                )  # (B, T')
                codebook_loss_mask = loss_mask[:, codebook, fs_index::frame_stacking_factor]
                codebook_loss = codebook_loss * codebook_loss_mask
                if codebook_loss_mask.sum() == 0:
                    logging.warning(f"Loss mask for codebook {codebook} is all zeros, global_step: {self.global_step}")
                    continue
                codebook_loss = codebook_loss.sum() / codebook_loss_mask.sum()
                if total_codebook_loss is None:
                    total_codebook_loss = codebook_loss
                else:
                    total_codebook_loss = total_codebook_loss + codebook_loss

        total_codebook_loss = total_codebook_loss / (audio_codes.size(1) * frame_stacking_factor)
        return total_codebook_loss, loss_mask

    def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prior, multi_encoder_mapping):
        decoder_out = self.decoder(
            dec_input_embedded,
            dec_input_mask,
            cond=cond,
            cond_mask=cond_mask,
            attn_prior=attn_prior,
            multi_encoder_mapping=multi_encoder_mapping,
        )
        attn_probabilities = decoder_out['attn_probabilities']
        all_code_logits = self.final_proj(decoder_out['output'])  # (B, T', num_codebooks * num_tokens_per_codebook)
        return all_code_logits, attn_probabilities, decoder_out['output']

    def logits_to_audio_codes(self, all_code_logits, audio_codes_lens):
        # all_code_logits: (B, T', num_codebooks * num_tokens_per_codebook)
        # audio_codes_lens: (B,)
        all_preds = [[] for _ in range(self.frame_stacking_factor)]
        for fs_index in range(self.frame_stacking_factor):
            for idx in range(self.num_audio_codebooks):
                si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook
                ei = si + self.num_all_tokens_per_codebook
                codebook_logits = all_code_logits[:, :, si:ei]
                codebook_probs = torch.softmax(codebook_logits, dim=-1)  # (B, T', num_tokens_per_codebook)
                # argmax to get the tokens
                codebook_preds = torch.argmax(codebook_probs, dim=-1)  # (B, T')
                all_preds[fs_index].append(codebook_preds)
        all_preds = [
            torch.stack(p, dim=1) for p in all_preds
        ]  # list of `frame_stacking_factor` elements of shape (B, C, T) each
        all_preds = torch.stack(all_preds, dim=-1)  # B, C, T, frame_stacking_factor
        # undo the frame stacking
        all_preds = all_preds.reshape(all_preds.size(0), all_preds.size(1), -1)  # B, C, T*frame_stacking_factor
        pred_max_len = all_preds.size(2)
        real_max_len = audio_codes_lens.max()
        assert (pred_max_len - real_max_len) < self.frame_stacking_factor
        # trim padding introduced for frame stacking
        all_preds = all_preds[:, :, :real_max_len]
        audio_mask = get_mask_from_lengths(audio_codes_lens)
        all_preds = all_preds * audio_mask.unsqueeze(1)

        return all_preds

    def visualize_codes(self, codes, mask_id=2020, frame_stacking_rate=2):
        """
        Visualize codes for analysis purposes
        codes: (B, C) where C = num_codebooks * frame_stacking_rate
        """

        def code_to_str(code):
            if code == mask_id:
                return "M    "
            else:
                return f"{code:04d} "

        B, C = codes.shape
        if B > 1:
            logging.debug("Warning: visualizing only first batch element")
        codes = codes.clone().detach().cpu().numpy()[0]
        codes = [code_to_str(c) for c in codes]
        output_str = ""
        for i, c in enumerate(codes):
            if i % (C // frame_stacking_rate) == 0:
                output_str += "|timestep| "
            output_str += c
        logging.debug(output_str)

    def clear_forbidden_logits(self, logits: torch.Tensor, forbid_audio_eos: bool = False) -> torch.Tensor:
        """
        Sets logits of forbidden tokens to `-inf` so they will never be sampled.
        Specifically, we forbid sampling of all special tokens except AUDIO_EOS
        which is allowed by default.
        Args:
            logits: (B, C, num_audio_tokens_per_codebook)
            forbid_audio_eos (bool, optional): If True, also forbid AUDIO_EOS tokens
                                               from being sampled. Default: False.
        """
        logits[
            :,
            :,
            SpecialAudioToken.get_forbidden_tokens(self.codebook_size, forbid_audio_eos=forbid_audio_eos),
        ] = float('-inf')
        return logits

    def local_transformer_sample_maskgit(
        self,
        dec_output: torch.Tensor,
        temperature: float = 0.7,
        topk: int = 80,
        unfinished_items: Dict[int, bool] = {},
        finished_items: Dict[int, bool] = {},
        use_cfg: bool = False,
        cfg_scale: float = 1.0,
        n_steps: int = 3,
        noise_scale: float = 0.0,
        fixed_schedule: Optional[List[int]] = None,
        dynamic_cfg_scale: bool = False,
        sampling_type: Optional[str] = None,
        forbid_audio_eos: bool = False,
    ) -> torch.Tensor:
        """
        Sample audio codes for the current timestep using MaskGit-like iterative
        prediction with the local transformer. If frame-stacking is enabled, the
        codes for all frames in the stack are sampled, treated as one long sequence.

        The MaskGit process starts with all positions masked and iteratively unmasks the
        most confident positions over multiple steps. By "masked" we mean that a
        dedicated MASK token is used (as opposed to attention masking). The LT in this
        case is a non-causal transformer decoder. At each step the model predicts all
        positions at once.  Of those predictions, a subset of the most confident
        previously-masked positions is kept and unmasked in the next step. The number of
        positions that are unmasked at each step is determined by the unmasking
        schedule. We support a cosine schedule and a fixed schedule provided by the
        user.

        Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG).

        Special handling:
        * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
        * forces / forbids EOS for finished / unfinished items respectively
        * optionally, globally forbids audio EOS for all items in the batch.
          This is useful early in the generation process.
        * supports different unmasking methods, see `sampling_type` argument for details.

        Args:
            dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size
                and E is primary decoder's embedding dimension.
            temperature (float, optional): Sampling temperature
            topk (int, optional): Number of top-probability tokens to consider in sampling.
            unfinished_items (dict, optional): Dictionary containing indices of batch
                items that we are confident have not completed generation. For these items, audio EOS
                sampling is forbidden.
            finished_items (dict, optional): Dictionary containing indices of batch
                items that we are confident are completed. For these items, audio EOS sampling
                is forced.
            use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size
                to be doubled with conditional and unconditional outputs from the primary decoder.
            cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
            n_steps (int, optional): Number of iterative refinement steps for MaskGit sampling.
            noise_scale (float, optional): Scale factor for noise to add to confidence scores
                during sampling (experimental).
            fixed_schedule (list, optional): Fixed schedule for number of tokens to unmask at each step.
                If None, uses cosine schedule.
            dynamic_cfg_scale (bool, optional): Whether to dynamically adjust CFG scale during
                sampling (experimental).
            sampling_type (str, optional): Type of sampling strategy. Options are:
             ["default", "causal", "purity_causal", "purity_default"].
             * Purity refers to "purity sampling" from https://arxiv.org/abs/2304.01515. If "purity"
               is not specified, confidence sampling is used as in the original MaskGit paper.
             * "default"/"causal": Controls the order of unmasking across frames when frame-stacking is enabled.
                                   If "causal" is specified, frames are unmasked in causal order. "default"
                                   doesn't impose any constraints on the unmasking order.
            forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
                batch.

        Returns:
            torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor)
        """
        # dec_output: (B, E)
        device = dec_output.device
        # disable KV cache since our transformer is not causal
        self.local_transformer.reset_cache(use_cache=False)
        dec_output = dec_output.unsqueeze(1)  # (B, 1, E)
        local_transformer_input_init = self.local_transformer_in_projection(
            dec_output
        )  # (B, 1, D) where D is the dimension of the local transformer
        codebook_seq_len = self.num_audio_codebooks * self.frame_stacking_factor
        B = dec_output.size(0)

        min_confidence = 0
        # this needs to be large enough that unmasked items will always remain unmasked (even after noise addition)
        # Setting it smaller could allow "regret", i.e. re-masking a codebook that was previously unmasked; we might want to try that
        max_confidence = 5
        confidences = min_confidence * torch.ones(B, codebook_seq_len, device=device)
        # initialize to all masked
        codes = self.mask_token_id * torch.ones((B, codebook_seq_len), device=device, dtype=torch.long)
        sampled_codes = codes.clone()
        topk_indices = None
        if fixed_schedule is not None:
            n_steps = len(fixed_schedule)
        for step in range(n_steps):
            # how far along we are in the unmasking process
            progress = step / n_steps
            # get mask fraction
            frac_masked = cosine_schedule(torch.tensor(progress))
            if sampling_type == "causal" or sampling_type == "purity_causal":
                frac_masked = torch.ones_like(frac_masked) * (1.0 - progress)
            # how many codebooks to mask
            if fixed_schedule is None:
                n_masked = torch.ceil(codebook_seq_len * frac_masked).long()
            else:
                n_masked = codebook_seq_len - fixed_schedule[step]
            n_unmasked = codebook_seq_len - n_masked

            if sampling_type == "causal" or sampling_type == "purity_causal":
                # In the causal modes, restrict unmasking to the earliest frames of the stack: positions belonging
                # to frames beyond `n_frames_to_allow` get a confidence below `min_confidence`, so they cannot be
                # selected for unmasking at this step (only tested for frame_stacking_factor=2).
                n_frames_to_allow = int(np.floor(progress * self.frame_stacking_factor + 1))
                confidences[:, n_frames_to_allow * self.num_audio_codebooks :] = min_confidence - 1

            # pick top-confidence codebooks up to n_unmasked
            _, topk_indices = torch.topk(confidences, k=n_unmasked, dim=1)
            if use_cfg:
                actual_batch_size = topk_indices.size(0) // 2
                assert (
                    topk_indices[actual_batch_size:] == topk_indices[:actual_batch_size]
                ).all(), "Topk indices are not the same for conditional and unconditional codes"

            # replace masks of the top-k confident codebooks with the codes that were sampled for them
            unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices)
            codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes)

            # build transformer input
            local_transformer_input = local_transformer_input_init
            for codebook_num in range(codebook_seq_len):
                next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze(
                    1
                )  # (B, 1, 768)
                next_local_transformer_input = self.local_transformer_in_projection(
                    next_local_transformer_input
                )  # (B, 1, d_local)
                local_transformer_input = torch.cat(
                    [local_transformer_input, next_local_transformer_input], dim=1
                )  # (B, codebook_num+1, d_local)

            # run transformer
            _mask = torch.ones(B, codebook_seq_len + 1, device=device)
            local_transformer_output = self.local_transformer(local_transformer_input, _mask)[
                'output'
            ]  # (B, C+1, d_local)

            # get logits
            logits = []
            for codebook_num in range(codebook_seq_len):
                # The `codebook_num + 1` skips the first position, which corresponds to the Magpie latent
                codebook_logits = self.local_transformer_out_projections[codebook_num](
                    local_transformer_output[:, codebook_num + 1, :]
                )  # (B, num_audio_tokens_per_codebook)
                logits.append(codebook_logits)
            logits = torch.stack(logits, dim=1)  # (B, C*frame_stacking_factor, num_audio_tokens_per_codebook)

            # apply CFG
            if use_cfg:
                actual_batch_size = logits.size(0) // 2
                conditional_logits = logits[:actual_batch_size]
                unconditional_logits = logits[actual_batch_size:]
                if not dynamic_cfg_scale:
                    current_cfg_scale = cfg_scale
                else:
                    # dynamically adjust the CFG scale over the course of sampling; the active choice ramps it
                    # up linearly from 1.0 to cfg_scale (commented-out alternatives are kept for experimentation)
                    progress = step / (n_steps - 1)
                    # interp = -abs(progress - 0.5) + 0.5  # rise from 0 to 0.5 at the midpoint, then back to 0
                    # interp = 1.0 - progress  # decrease from 1 to 0
                    interp = progress  # gradually increase from 0 to 1
                    current_cfg_scale = (cfg_scale - 1) * interp + 1.0  # 1.0 --> cfg_scale
                logits[:actual_batch_size] = cfg_logits
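                # For example, with a scale of 2.5 the combined logits are 2.5 * conditional - 1.5 * unconditional,
                # i.e. the prediction is pushed further away from the unconditional distribution.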

            # Disallow generation of special tokens
            logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos)

            # handle unfinished and finished items
            for item_idx in unfinished_items:
                logits[item_idx, :, self.audio_eos_id] = float('-inf')
            for item_idx in finished_items:
                logits[item_idx, :, :] = float('-inf')
                logits[item_idx, :, self.audio_eos_id] = 0.0

            # sample with top-k
            logits_topk = torch.topk(logits, topk, dim=-1)[0]  # (B, C, topk)
            indices_to_remove = logits < logits_topk[:, :, -1].unsqueeze(-1)  # (B, C, num_audio_tokens_per_codebook)
            logits_rescored = logits.clone()
            logits_rescored[indices_to_remove] = float('-inf')
            probs = torch.softmax(logits_rescored / temperature, dim=-1)  # (B, C, num_audio_tokens_per_codebook)
            sampled_codes = torch.multinomial(probs.view(B * codebook_seq_len, -1), 1).view(B, codebook_seq_len)
            if use_cfg:
                sampled_codes[actual_batch_size:] = sampled_codes[:actual_batch_size]
                probs[actual_batch_size:] = probs[:actual_batch_size]
            if sampling_type != "purity_causal" and sampling_type != "purity_default":
                confidences = torch.gather(probs, dim=2, index=sampled_codes.unsqueeze(-1)).squeeze(-1)
            else:
                # use the max probability across all tokens as the confidence of each codebook position; known as "purity sampling"
                confidences = probs.max(dim=2)[0]
            # replace entries in sampled_codes with previously unmasked codebooks
            sampled_codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes)
            # add noise to the confidences (as in the token-critic paper, https://arxiv.org/abs/2209.04439)
            if noise_scale > 0.0:
                # get noise from a uniform distribution in the interval [-0.5, 0.5), scale it by `noise_scale`,
                # and anneal it toward zero as we approach the end of the unmasking process
                # (the `step + 2` offset makes the annealing factor reach zero one step before the last iteration)
                noise = (torch.rand_like(confidences) - 0.5) * noise_scale * (1 - (step + 2) / n_steps)
                confidences += noise
                if use_cfg:
                    # the conditional and unconditional halves receive different noise and must be made identical again
                    confidences[actual_batch_size:] = confidences[:actual_batch_size]
            confidence_eps = 0.1
            assert (
                confidences.max() + confidence_eps < max_confidence
            ), f"Predicted confidence is approaching max_confidence: {confidences.max()}"
            # for unmasked codebooks, set confidence to max so that they will remain unmasked
            confidences.scatter_(
                index=topk_indices, dim=1, src=max_confidence * torch.ones_like(topk_indices, dtype=torch.float)
            )
        codes = sampled_codes
        assert not (
            codes == self.mask_token_id
        ).any(), "Codes contain mask tokens after completion of MaskGit sampling"

        # break stacked groups of frames into individual frames
        codes = codes.reshape(B, self.frame_stacking_factor, self.num_audio_codebooks).permute(
            0, 2, 1
        )  # B, C, frame_stacking_factor

        if use_cfg:
            # drop unconditional codes
            codes = codes[:actual_batch_size]
        return codes

    def local_transformer_sample_autoregressive(
        self,
        dec_output: torch.Tensor,
        temperature: float = 0.7,
        topk: int = 80,
        unfinished_items: Dict[int, bool] = {},
        finished_items: Dict[int, bool] = {},
        use_cfg: bool = False,
        cfg_scale: float = 1.0,
        use_kv_cache: bool = True,
        forbid_audio_eos: bool = False,
    ) -> torch.Tensor:
        """
        Sample audio codes autoregressively across codebooks using the local
        transformer. Uses multinomial sampling with temperature, top-k, and
        classifier-free guidance (CFG).

        The sequence is initialized with the primary decoder's hidden output as its only
        input and is gradually extended by one codebook's code at a time, appending each
        sampled code to the input sequence for the next step. At the last step the sequence
        is `num_codebooks` long. If frame stacking is enabled, codes for all frames in
        the stack are sampled as one long sequence and the final sequence length is
        `num_codebooks * frame_stacking_factor` codes long.

        Special handling:
        * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
        * forces / forbids EOS for finished / unfinished items respectively
        * optionally, globally forbids audio EOS (useful early in the generation process)

        Args:
            dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size
                and E is primary decoder's embedding dimension.
            temperature (float, optional): Sampling temperature.
            topk (int, optional): Number of top-probability tokens to consider in sampling.
            unfinished_items (dict, optional): Dictionary containing indices of batch
                items that we are confident have not completed generation. For these items, audio EOS
                sampling is forbidden.
            finished_items (dict, optional): Dictionary containing indices of batch
                items that we are confident are completed. For these items, audio EOS sampling
                is forced.
            use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size
                to be doubled with conditional and unconditional outputs from the primary decoder.
            cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
            use_kv_cache (bool, optional): Whether to use key-value caching in the transformer.
            forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
                batch.

        Returns:
            torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor)
                where B is batch size (or actual_batch_size if use_cfg=True).
        """

        self.local_transformer.reset_cache(use_cache=use_kv_cache)
        dec_output = dec_output.unsqueeze(1)  # (B, 1, E)
        local_transformer_input = self.local_transformer_in_projection(dec_output)  # (B, 1, 128)
        all_preds = []
        for codebook_num in range(self.num_audio_codebooks * self.frame_stacking_factor):
            _mask = torch.ones(
                local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device
            )
            local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output']  # (B, T, 128)
            codebook_logits = self.local_transformer_out_projections[codebook_num](
                local_transformer_output[:, -1, :]
            )  # (B, num_all_tokens_per_codebook)
            if use_cfg:
                actual_batch_size = codebook_logits.size(0) // 2
                conditional_logits = codebook_logits[:actual_batch_size]
                unconditional_logits = codebook_logits[actual_batch_size:]
                cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits
                codebook_logits[:actual_batch_size] = cfg_logits

            for item_idx in unfinished_items:
                codebook_logits[item_idx, self.audio_eos_id] = float('-inf')
            for item_idx in finished_items:
                codebook_logits[item_idx, :] = float('-inf')
                codebook_logits[item_idx, self.audio_eos_id] = 0.0

            # Disallow generation of special tokens
            codebook_logits = self.clear_forbidden_logits(
                codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
            ).squeeze(1)

            codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0]  # (B, topk)
            indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
                -1
            )  # (B, num_tokens_per_codebook)
            codebook_logits_rescored = codebook_logits.clone()
            codebook_logits_rescored[indices_to_remove] = float('-inf')
            codebook_probs = torch.softmax(
                codebook_logits_rescored / temperature, dim=-1
            )  # (B, num_tokens_per_codebook)
            codebook_preds = torch.multinomial(codebook_probs, 1)  # (B, 1)
            if use_cfg:
                codebook_preds[actual_batch_size:] = codebook_preds[:actual_batch_size]
            all_preds.append(codebook_preds)
            next_local_transformer_input = self.audio_embeddings[codebook_num](codebook_preds.squeeze(-1)).unsqueeze(
                1
            )  # (B, 1, 128)
            next_local_transformer_input = self.local_transformer_in_projection(
                next_local_transformer_input
            )  # (B, 1, 128)
            local_transformer_input = torch.cat(
                [local_transformer_input, next_local_transformer_input], dim=1
            )  # (B, T+1, 128)

        all_preds = torch.cat(all_preds, dim=1).long()  # (B, num_codebooks * frame_stacking_factor)
        all_preds = all_preds.reshape(-1, self.frame_stacking_factor, self.num_audio_codebooks).permute(
            0, 2, 1
        )  # (B, num_codebooks, frame_stacking_factor)
        if use_cfg:
            all_preds = all_preds[:actual_batch_size]

        return all_preds

    def sample_codes_from_logits(
        self,
        all_code_logits_t: torch.Tensor,
        temperature: float = 0.7,
        topk: int = 80,
        unfinished_items: Dict[int, bool] = {},
        finished_items: Dict[int, bool] = {},
        forbid_audio_eos: bool = False,
    ) -> torch.Tensor:
        """
        Sample codes for all codebooks at a given timestep. Uses multinomial sampling
        with temperature and top-k. If frame stacking is on (i.e. `frame_stacking_factor
        > 1`), this function will sample across the entire frame stack.

        Special handling:
        * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
        * forces / forbids EOS for finished / unfinished items respectively
        * optionally, globally forbids audio EOS (useful early in the generation process)

        Args:
            all_code_logits_t (torch.Tensor): Logits at a given timestep with shape
                (B, num_tokens_per_codebook * num_codebooks * frame_stacking_factor)
            temperature (float, optional): Sampling temperature
            topk (int, optional): Number of top-probability tokens to consider in sampling.
            unfinished_items (dict, optional): Dictionary containing indices of batch
                items that we are confident have not completed generation. For these items, audio EOS
                sampling is forbidden.
            finished_items (dict, optional): Dictionary containing indices of batch
                items that we are confident are completed. For these items, audio EOS sampling
                is forced.
            forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
                batch.

        Returns:
            torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor).
        """
        all_preds = [[] for _ in range(self.frame_stacking_factor)]
        for fs_index in range(self.frame_stacking_factor):
            for idx in range(self.num_audio_codebooks):
                si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook
                ei = si + self.num_all_tokens_per_codebook
                codebook_logits = all_code_logits_t[:, si:ei]  # (B, num_tokens_per_codebook)

                for item_idx in unfinished_items:
                    codebook_logits[item_idx, self.audio_eos_id] = float('-inf')
                for item_idx in finished_items:
                    codebook_logits[item_idx, :] = float('-inf')
                    codebook_logits[item_idx, self.audio_eos_id] = 0.0

                # Disallow generation of special tokens
                codebook_logits = self.clear_forbidden_logits(
                    codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
                ).squeeze(1)

                codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0]  # (B, topk)
                indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
                    -1
                )  # (B, num_tokens_per_codebook)
                codebook_logits_rescored = codebook_logits.clone()
                codebook_logits_rescored[indices_to_remove] = float('-inf')

                codebook_probs = torch.softmax(
                    codebook_logits_rescored / temperature, dim=-1
                )  # (B, num_tokens_per_codebook)
                codebook_preds = torch.multinomial(codebook_probs, 1)  # (B, 1)
                all_preds[fs_index].append(codebook_preds)

        all_preds = [
            torch.cat(ds_preds, dim=1).long() for ds_preds in all_preds
        ]  # list of `frame_stacking_factor` elements, each of shape (B, num_codebooks)
        all_preds = torch.stack(all_preds, dim=2)  # (B, num_codebooks, frame_stacking_factor)
        return all_preds

    def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens, prefix="", dec_context_size=0):
        # attention_prob_matrix List of (B, C, audio_timesteps, text_timesteps)
        wandb_images_log = {}

        with torch.no_grad():
            attention_prob_matrix = torch.cat(attention_prob_matrix, dim=1)  # (B, C, audio_timesteps, text_timesteps)
            attention_prob_matrix_mean = attention_prob_matrix.mean(dim=1)  # (B, audio_timesteps, text_timesteps)

            for logger in self.loggers:
                is_wandb = isinstance(logger, WandbLogger)
                is_tb = isinstance(logger, TensorBoardLogger)
                if not is_wandb and not is_tb:
                    raise ValueError(
                        f"Invalid logger type for image logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
                    )

                wandb_images_log[f"Image/{prefix}/attention_matrix"] = list()
                for idx in range(min(3, attention_prob_matrix_mean.size(0))):
                    item_attn_matrix = attention_prob_matrix_mean[idx][
                        dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx]
                    ]
                    item_attn_matrix = item_attn_matrix.detach().cpu().numpy()
                    img_np = plot_alignment_to_numpy(item_attn_matrix.T)

                    if is_wandb:
                        wandb_images_log[f"Image/{prefix}/attention_matrix"].append(
                            wandb.Image(img_np, caption=f"Example_{idx}")
                        )

                    if is_tb:
                        logger.experiment.add_image(
                            f'{prefix}/attention_matrix/Example_{idx}',
                            img_np,
                            global_step=self.global_step,
                            dataformats="HWC",
                        )

        return wandb_images_log

    def log_val_audio_example(
        self,
        logits,
        target_audio_codes,
        audio_codes_lens_target,
        context_audio_codes=None,
        context_audio_codes_lens=None,
    ):
        wandb_audio_log = {}

        pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens_target)
        pred_audio, pred_audio_lens = self.codes_to_audio(pred_audio_codes, audio_codes_lens_target)
        target_audio, target_audio_lens = self.codes_to_audio(target_audio_codes, audio_codes_lens_target)

        context_audio, context_audio_lens = None, None
        if context_audio_codes is not None and context_audio_codes.shape[2] > 3:
            # > 3 ensures it is a valid context audio tensor (and not a dummy tensor used for text context)
            context_audio, context_audio_lens = self.codes_to_audio(context_audio_codes, context_audio_codes_lens)

        for logger in self.loggers:
            is_wandb = isinstance(logger, WandbLogger)
            is_tb = isinstance(logger, TensorBoardLogger)
            if not is_wandb and not is_tb:
                raise ValueError(
                    f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
                )

            for idx in range(min(3, pred_audio.size(0))):
                pred_audio_np = pred_audio[idx].float().detach().cpu().numpy()
                target_audio_np = target_audio[idx].float().detach().cpu().numpy()
                pred_audio_np = pred_audio_np[: pred_audio_lens[idx]]
                target_audio_np = target_audio_np[: target_audio_lens[idx]]
                context_audio_np = None
                if context_audio is not None:
                    context_audio_np = context_audio[idx].float().detach().cpu().numpy()
                    context_audio_np = context_audio_np[: context_audio_lens[idx]]

                if is_wandb:
                    wandb_audio_log[f"Audio/Example_{idx}"] = list()
                    if context_audio_np is not None:
                        wandb_audio_log[f"Audio/Example_{idx}"].append(
                            wandb.Audio(context_audio_np, sample_rate=self.sample_rate, caption="context")
                        )
                    wandb_audio_log[f"Audio/Example_{idx}"].append(
                        wandb.Audio(pred_audio_np, sample_rate=self.sample_rate, caption="prediction")
                    )
                    wandb_audio_log[f"Audio/Example_{idx}"].append(
                        wandb.Audio(target_audio_np, sample_rate=self.sample_rate, caption="target")
                    )

                if is_tb:
                    if context_audio_np is not None:
                        logger.experiment.add_audio(
                            f'Example_{idx}/context',
                            context_audio_np,
                            global_step=self.global_step,
                            sample_rate=self.sample_rate,
                        )
                    logger.experiment.add_audio(
                        f'Example_{idx}/prediction',
                        pred_audio_np,
                        global_step=self.global_step,
                        sample_rate=self.sample_rate,
                    )
                    logger.experiment.add_audio(
                        f'Example_{idx}/target',
                        target_audio_np,
                        global_step=self.global_step,
                        sample_rate=self.sample_rate,
                    )

        return wandb_audio_log

    def scale_prior(self, prior, global_step):
        if prior is None:
            return None
        if global_step < self.prior_scaledown_start_step:
            return prior
        elif global_step >= self.prior_end_step:
            if random.random() < self.indefinite_prior_prob:
                logging.debug("Using Prior")
                return prior
            else:
                logging.debug("Not using Prior")
                return None
        else:
            with torch.no_grad():
                # Interpolate between all ones and the prior
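                # For illustration (hypothetical numbers): halfway through the window, i.e. when global_step is
                # at the midpoint of [prior_scaledown_start_step, prior_end_step], a prior entry of 0.2 becomes
                # 0.2 + 0.5 * (1 - 0.2) = 0.6, so the prior gradually relaxes toward an all-ones (uninformative) matrix.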
                residual = 1.0 - prior
                new_prior = prior + (
                    residual
                    * (global_step - self.prior_scaledown_start_step)
                    / (self.prior_end_step - self.prior_scaledown_start_step)
                )
                return new_prior

    def embed_text(self, text, text_mask):
        if self.use_bpe_char_tokenizer:
            text_embedded = self.cas_encoder(text, subword_mask=text_mask)
        else:
            text_embedded = self.text_embedding(text)

        return text_embedded

    def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_context_size=0):
        # attention scores: List of (B, C, audio_timesteps, text_timesteps)
        attention_scores_combined = torch.cat(attention_scores, dim=1)  # (B, C, audio_timesteps, text_timesteps)
        attention_scores_mean = attention_scores_combined.mean(
            dim=1, keepdim=True
        )  # (B, 1, audio_timesteps, text_timesteps)
        attention_scores_mean = attention_scores_mean[
            :, :, dec_context_size:, :
        ]  # Remove the context audio embeddings from the attention scores
        alignment_loss = self.alignment_loss(
            attn_logprob=attention_scores_mean, in_lens=text_lens, out_lens=audio_lens
        )
        return alignment_loss

    def pad_audio_codes(self, audio_codes: torch.Tensor, frame_stacking_factor: int = 1, pad_token: int = 0):
        """
        Pads the time dimension of the audio codes to a multiple of the frame stacking factor.
        Args:
            audio_codes (torch.Tensor): B, C, T
            frame_stacking_factor (int): The factor that frames will be stacked by.
            pad_token (int): The token ID to pad with.
        Returns:
            B, C, T_padded
        """
        T = audio_codes.size(2)
        T_padded = int(np.ceil(T / frame_stacking_factor) * frame_stacking_factor)
        if T_padded > T:
            padding = pad_token * torch.ones(
                audio_codes.size(0),
                audio_codes.size(1),
                T_padded - T,
                device=audio_codes.device,
                dtype=audio_codes.dtype,
            )
            audio_codes = torch.cat([audio_codes, padding], dim=2)
        return audio_codes

    def embed_context_text(self, context_text_tokens):
        if self.legacy_text_conditioning:
            context_text_tokens = (
                context_text_tokens - self.tokenizer.tokenizer_offsets[self.text_conditioning_tokenizer_name]
            )
            context_text_embedded = self.context_text_embedding(context_text_tokens)  # (B, L, E)
        else:
            context_text_embedded = self.text_embedding(context_text_tokens)  # (B, L, E)

        return context_text_embedded

    def prepare_context_tensors(self, batch):
        dec_context_size = 0
        additional_decoder_input = None
        additional_decoder_mask = None
        context_audio_codes = None
        context_audio_codes_lens = None
        _attn_prior = None
        attn_prior = None
        cond = None
        cond_mask = None
        multi_encoder_mapping = None
        text = None
        text_lens = None

        # self.model_type must be one of [multi_encoder_context_tts, decoder_context_tts, decoder_ce]
        text = batch['text']
        text_lens = batch['text_lens']
        text_mask = get_mask_from_lengths(text_lens)  # (B, T)
        text_embedded = self.embed_text(text, text_mask)  # (B, T, E)
        text_encoder_out = self.encoder(text_embedded, text_mask, cond=None, cond_mask=None)['output']  # (B, T, E)
        _attn_prior = batch.get('align_prior_matrix', None)
        _attn_prior = self.scale_prior(_attn_prior, self.global_step)

        if self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts', 'decoder_ce']:
            if 'context_audio_codes' in batch:
                context_audio_codes = batch['context_audio_codes']
                context_audio_codes_lens = batch['context_audio_codes_lens']
                if self._codec_converter is not None:
                    context_audio_codes = self._codec_converter.convert_original_to_new(
                        audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens
                    ).long()
            else:
                context_audio_codes, context_audio_codes_lens = self.audio_to_codes(
                    batch['context_audio'], batch['context_audio_lens'], audio_type='context'
                )
            context_audio_codes = self.pad_audio_codes(context_audio_codes, self.frame_stacking_factor, pad_token=0)
            context_audio_embedded = self.embed_audio_tokens(context_audio_codes)  # (B, T/frame_stacking_factor, E)

            if self.use_text_conditioning_encoder:
                context_text_tokens = batch['context_text_tokens']
                context_text_lens = batch['context_text_tokens_lens']
                context_text_embedded = self.embed_context_text(context_text_tokens)  # (B, L, E)

                # Pad context_audio_embedded or context_text_embedded so that they have the same number of timesteps
                if context_audio_embedded.size(1) < context_text_embedded.size(1):
                    padding = torch.zeros(
                        context_audio_embedded.size(0),
                        context_text_embedded.size(1) - context_audio_embedded.size(1),
                        context_audio_embedded.size(2),
                        device=context_audio_embedded.device,
                    )
                    context_audio_embedded = torch.cat([context_audio_embedded, padding], dim=1)
                elif context_audio_embedded.size(1) > context_text_embedded.size(1):
                    padding = torch.zeros(
                        context_text_embedded.size(0),
                        context_audio_embedded.size(1) - context_text_embedded.size(1),
                        context_text_embedded.size(2),
                        device=context_text_embedded.device,
                    )
                    context_text_embedded = torch.cat([context_text_embedded, padding], dim=1)  # (B, T, E)
                has_text_context = batch['has_text_context'].unsqueeze(-1).unsqueeze(-1).float()  # (B, 1, 1)
                context_input_embedded = (
                    has_text_context * context_text_embedded + (1 - has_text_context) * context_audio_embedded
                )
                context_input_lens = (
                    batch['has_text_context'].float() * context_text_lens
                    + (1 - batch['has_text_context'].float()) * context_audio_codes_lens
                )  # (B,)
            else:
                context_input_embedded = context_audio_embedded
                context_input_lens = context_audio_codes_lens
                context_input_lens = torch.ceil(context_input_lens / self.frame_stacking_factor).to(
                    context_input_lens.dtype
                )

            context_mask = get_mask_from_lengths(context_input_lens)

            if self.model_type == 'multi_encoder_context_tts':
                context_embeddings = self.context_encoder(
                    context_input_embedded, context_mask, cond=None, cond_mask=None
                )['output']
                cond = [text_encoder_out, context_embeddings]
                cond_mask = [text_mask, context_mask]
                multi_encoder_mapping = self.multi_encoder_mapping
                attn_prior = [_attn_prior, None]

            elif self.model_type in ['decoder_context_tts', 'decoder_ce']:
                context_embeddings = None  # Address CodeQL
                if self.model_type == 'decoder_context_tts':
                    context_embeddings = context_input_embedded
                elif self.model_type == 'decoder_ce':
                    # Check for baked context embedding first
                    if self.has_baked_context_embedding:
                        # self.baked_context_embedding is a fixed context embedding that is baked into the model.
                        # This is used when we do not want users to generate speech with context audio or context text.
                        # This is done to disable zero-shot inference. Users can only generate speech in 1 voice chosen
                        # by the model development team.
                        batch_size = text.size(0)
                        # Expand baked embedding to batch size: (T, E) -> (B, T, E)
                        context_embeddings = self.baked_context_embedding.unsqueeze(0).expand(batch_size, -1, -1)
                        # Create context mask from baked length
                        context_input_lens = (
                            self.baked_context_embedding_len.unsqueeze(0).expand(batch_size).to(text.device)
                        )
                        context_mask = get_mask_from_lengths(context_input_lens)
                    else:
                        context_embeddings = self.context_encoder(
                            context_input_embedded, context_mask, cond=None, cond_mask=None
                        )['output']
                dec_context_size = context_mask.size(1)
                attn_prior = _attn_prior
                if attn_prior is not None:
                    # B, audio_timesteps, text_timesteps
                    padding_zeros = torch.zeros(
                        attn_prior.size(0), dec_context_size, attn_prior.size(2), device=attn_prior.device
                    )
                    attn_prior = torch.cat([padding_zeros, attn_prior], dim=1)
                cond = text_encoder_out
                cond_mask = text_mask
                multi_encoder_mapping = None
                additional_decoder_input = context_embeddings
                additional_decoder_mask = context_mask
        else:
            raise ValueError(f"Unsupported model type {self.model_type}")

        if attn_prior is not None and self.ctc_prior_layer_ids is not None:
            # Convert prior to a list of tensors, one for each layer
            # Set None for layers not in ctc_prior_layer_ids
            if self.model_type == 'multi_encoder_context_tts':
                text_attn_prior = [
                    attn_prior[0] if layer_idx in self.ctc_prior_layer_ids else None
                    for layer_idx in range(self.decoder.n_layers)
                ]
                attn_prior = [text_attn_prior, attn_prior[1]]
            else:
                attn_prior = [
                    attn_prior if layer_idx in self.ctc_prior_layer_ids else None
                    for layer_idx in range(self.decoder.n_layers)
                ]

        return {
            'beta_binomial_attn_prior': batch.get('align_prior_matrix', None),
            'text_encoder_out': text_encoder_out,
            'cond': cond,
            'cond_mask': cond_mask,
            'attn_prior': attn_prior,
            'prior_used': _attn_prior is not None,
            'multi_encoder_mapping': multi_encoder_mapping,
            'additional_decoder_input': additional_decoder_input,
            'additional_decoder_mask': additional_decoder_mask,
            'dec_context_size': dec_context_size,
            'text': text,
            'text_embedded': text_embedded,
            'text_mask': text_mask,
            'text_lens': text_lens,
            'context_audio_codes': context_audio_codes,
            'context_audio_codes_lens': context_audio_codes_lens,
        }

    def replace_beta_binomial_prior_with_binarized(self, attn_prior, aligner_attn_hard):
        # aligner_attn_hard B, audio_timesteps, text_timesteps
        if self.model_type == 'multi_encoder_context_tts':
            text_attn_prior = attn_prior[0]
        else:
            text_attn_prior = attn_prior

        assert text_attn_prior is not None, "Prior is None"

        if isinstance(text_attn_prior, list):
            # Layer wise prior
            prior_updated = False
            for idx, prior in enumerate(text_attn_prior):
                if prior is not None:
                    text_attn_prior[idx][:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard
                    prior_updated = True
            assert prior_updated, "Did not find any prior to update"
        else:
            # Same prior for all layers
            text_attn_prior[:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard

        if self.model_type == 'multi_encoder_context_tts':
            attn_prior[0] = text_attn_prior
        else:
            attn_prior = text_attn_prior

        return attn_prior

    def get_binarized_prior_matrix(self, aligner_attn_soft, audio_lens, text_lens):
        # aligner_attn_soft B, 1, audio_timesteps, text_timesteps
        if self.binarize_attn_method == 'nemo_binarize':
            logging.debug("Binarizing attention using nemo_binarize")
            binarize_repeat_audio_factor = self.binarize_repeat_audio_factor
            aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave(
                binarize_repeat_audio_factor, dim=2
            )  # B, 1, 2*audio_timesteps, text_timesteps
            aligner_attn_hard = binarize_attention_parallel(
                aligner_attn_soft_repeated, text_lens, audio_lens * binarize_repeat_audio_factor
            ).squeeze(
                1
            )  # B, 2*audio_timesteps, text_timesteps
            aligner_attn_hard = aligner_attn_hard[:, ::2, :]  # B, audio_timesteps, text_timesteps
        elif self.binarize_attn_method == 'argmax':
            logging.debug("Binarizing attention using argmax")
            aligner_attn_hard = torch.argmax(aligner_attn_soft.squeeze(1), dim=-1)
            aligner_attn_hard = torch.nn.functional.one_hot(
                aligner_attn_hard, num_classes=aligner_attn_soft.size(-1)
            ).float()
        else:
            raise ValueError(
                f"self.binarize_attn_method '{self.binarize_attn_method}' must be one of 'nemo_binarize' or 'argmax'."
            )

        aligner_attn_hard_wider = aligner_attn_hard + self.binarized_prior_epsilon
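        # For illustration (hypothetical settings): with prior_future_context = 1, prior_future_decay = 0.5,
        # prior_past_context = 1 and prior_past_decay = 0.5, a hard-attention row that is one-hot at text
        # position 2, e.g. [0, 0, 1, 0, 0], is widened by the loops below (ignoring epsilon and clamping)
        # to [0, 0.5, 1, 0.5, 0]: decayed attention mass is spread to the neighboring text positions.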

        for future_timestep in range(self.prior_future_context):
            decay_factor = self.prior_future_decay ** (future_timestep + 1)
            aligner_attn_hard_wider[:, :, future_timestep + 1 :] += (
                decay_factor * aligner_attn_hard[:, :, : -(future_timestep + 1)]
            )

        for past_timestep in range(self.prior_past_context):
            decay_factor = self.prior_past_decay ** (past_timestep + 1)
            aligner_attn_hard_wider[:, :, : -past_timestep - 1] += (
                decay_factor * aligner_attn_hard[:, :, past_timestep + 1 :]
            )

        aligner_attn_hard_wider = torch.clamp(aligner_attn_hard_wider, 0.0, 1.0)
        return aligner_attn_hard_wider

    def prepare_dummy_cond_for_cfg(self, cond, cond_mask, additional_decoder_input, additional_dec_mask):
        dummy_additional_decoder_input = None
        dummy_additional_dec_mask = None
        if additional_decoder_input is not None:
            dummy_additional_decoder_input = torch.zeros_like(additional_decoder_input)
            # an all-ones mask means don't ignore any timesteps (so that it is consistent with the usual decoder mask)
            dummy_additional_dec_mask = torch.ones_like(additional_dec_mask)

        if isinstance(cond, list):
            # multi encoder conditioning
            dummy_cond = [torch.zeros_like(cond_item) for cond_item in cond]
            attn_prior = [None for _ in cond]
            dummy_mask = []
            for mask_item in cond_mask:
                # ignore all timesteps except the first one
                mask = torch.zeros_like(mask_item)
                mask[:, 0] = 1  # keep only the first timestep
                dummy_mask.append(mask)

        elif isinstance(cond, torch.Tensor):
            # single encoder conditioning
            dummy_cond = torch.zeros_like(cond)
            dummy_mask = torch.zeros_like(cond_mask)
            dummy_mask[:, 0] = 1  # ignore all timesteps except the first one
            attn_prior = None
        else:
            raise ValueError(f"Unsupported type for cond {type(cond)}")

        return dummy_cond, dummy_mask, dummy_additional_decoder_input, dummy_additional_dec_mask, attn_prior

    def process_batch(self, batch, mode="train"):
        context_tensors = self.prepare_context_tensors(batch)
        disable_alignment_loss = False
        if 'audio_codes' not in batch:
            audio_codes, audio_codes_lens = self.audio_to_codes(batch['audio'], batch['audio_lens'])
        else:
            audio_codes = batch['audio_codes']
            audio_codes_lens = batch['audio_codes_lens']
            if self._codec_converter:
                audio_codes = self._codec_converter.convert_original_to_new(
                    audio_tokens=audio_codes, audio_lens=audio_codes_lens
                ).long()
        if self.frame_stacking_factor > 1:
            # repeat the BOS token frame_stacking_factor times. This is necessary since at inference
            # we need to start autoregressive generation from a full stack of BOS tokens.
            # TODO: @rfejgin: this assert might be slow due to GPU/CPU sync
            assert (audio_codes[:, :, 0] == self.audio_bos_id).all(), "Audio codes do not start with BOS token"
            audio_codes = torch.cat(
                [
                    torch.full(
                        (audio_codes.size(0), audio_codes.size(1), self.frame_stacking_factor - 1),
                        self.audio_bos_id,
                        device=audio_codes.device,
                        dtype=audio_codes.dtype,
                    ),
                    audio_codes,
                ],
                dim=2,
            )
            audio_codes_lens += self.frame_stacking_factor - 1  # account for BOS repeat
            audio_codes = self.pad_audio_codes(audio_codes, self.frame_stacking_factor, pad_token=0)
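            # For example, with frame_stacking_factor = 2 the code sequence [BOS, c1, c2, c3, EOS] becomes
            # [BOS, BOS, c1, c2, c3, EOS] (then padded, if needed, to a multiple of frame_stacking_factor),
            # so the first stacked frame consists entirely of BOS tokens.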
        # Note: if a tensor lacks the `_unstacked` suffix, it can be assumed to be in the frame-stacked domain

        # drop last (stacked) frame since it is not part of *input*
        audio_codes_input_unstacked = audio_codes[:, :, : -self.frame_stacking_factor]  # B, C, T'
        # drop first (stacked) frame which contains BOS token(s) which are not part of *target*
        audio_codes_target_unstacked = audio_codes[:, :, self.frame_stacking_factor :]
        audio_codes_lens_input_unstacked = audio_codes_lens - 1  # don't count EOS for input
        audio_codes_lens_target_unstacked = audio_codes_lens - self.frame_stacking_factor  # don't count BOS for target
        audio_codes_lens_input = torch.floor(audio_codes_lens_input_unstacked / self.frame_stacking_factor).long()
        audio_codes_embedded_all = self.embed_audio_tokens(
            audio_codes
        )  # (B, T, E) # Computed here so it can also be used by the alignment encoder
        audio_codes_embedded = audio_codes_embedded_all[
            :, :-1, :
        ]  # (B, T', E) Input to the decoder; this is already in the frame-stacked domain, hence the -1 (not `frame_stacking_factor`)

        audio_codes_mask = get_mask_from_lengths(audio_codes_lens_input)
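        # Classifier-free guidance training: with probability cfg_unconditional_prob, drop the conditioning
        # and train on a dummy (all-zeros) condition so the model also learns an unconditional distribution.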
        use_cfg = (self.cfg_unconditional_prob > 0.0) and (mode == "train") and (context_tensors['cond'] is not None)
        if use_cfg and torch.rand(1).item() < self.cfg_unconditional_prob:
            cond, cond_mask, additional_decoder_input, additional_decoder_mask, attn_prior = (
                self.prepare_dummy_cond_for_cfg(
                    context_tensors['cond'],
                    context_tensors['cond_mask'],
                    context_tensors['additional_decoder_input'],
                    context_tensors['additional_decoder_mask'],
                )
            )
            disable_alignment_loss = True
        else:
            cond = context_tensors['cond']
            cond_mask = context_tensors['cond_mask']
            additional_decoder_input = context_tensors['additional_decoder_input']
            additional_decoder_mask = context_tensors['additional_decoder_mask']
            attn_prior = context_tensors['attn_prior']

            if mode == "train" and self.decoder_input_dropout_prob > 0.0 and torch.rand(1).item() < 0.5:
                # For roughly half of the training batches, replace a decoder_input_dropout_prob fraction of the
                # input timesteps with random tokens
                max_codebook_val = self.dec_random_input_max
                # @pneekhara: Keeping dec_random_input_max configurable since num_all_tokens_per_codebook usually has padding tokens,
                # which can cause errors when doing codes_to_audio on audio_codes_input. We are not currently calling codes_to_audio on
                # audio_codes_input, so it should not matter if we don't supply dec_random_input_max.
                random_audio_tokens = torch.randint(
                    0, max_codebook_val, audio_codes_input_unstacked.size(), device=audio_codes_input_unstacked.device
                )
                random_audio_tokens = random_audio_tokens * audio_codes_mask.unsqueeze(1)
                dec_dropout_mask = (
                    torch.rand((1, 1, audio_codes_input_unstacked.size(2)), device=audio_codes_input_unstacked.device)
                    > self.decoder_input_dropout_prob
                )
                # dec_dropout_mask is True for timesteps whose original tokens are kept
                audio_codes_input_unstacked = audio_codes_input_unstacked * dec_dropout_mask + random_audio_tokens * (
                    ~dec_dropout_mask
                )
                audio_codes_embedded = self.embed_audio_tokens(audio_codes_input_unstacked)  # (B, T', E)

        if context_tensors['additional_decoder_input'] is not None:
            dec_input_embedded = torch.cat([additional_decoder_input, audio_codes_embedded], dim=1)
            dec_input_mask = torch.cat([additional_decoder_mask, audio_codes_mask], dim=1)
        else:
            dec_input_embedded = audio_codes_embedded
            dec_input_mask = audio_codes_mask

        aligner_encoder_loss = None
        aligner_attn_soft = None
        aligner_attn_hard = None
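        # Optional alignment encoder: produces a soft text-to-audio alignment (its loss is only computed for the
        # first aligner_encoder_train_steps steps), which is binarized and can replace the beta-binomial attention prior.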
        if self.use_alignment_encoder and not disable_alignment_loss:
            aligner_prior = None
            if self.use_prior_for_aligner:
                aligner_prior = context_tensors['beta_binomial_attn_prior']
            # Passing target audio embeddings to the alignment encoder
            if self.global_step < self.aligner_encoder_train_steps:
                aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder(
                    queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1),  # B, E, T'
                    keys=context_tensors['text_encoder_out'].permute(0, 2, 1),  # B, E, T
                    mask=~context_tensors['text_mask'].unsqueeze(-1),
                    attn_prior=aligner_prior,
                )

                aligner_encoder_loss = self.alignment_encoder_loss(
                    attn_logprob=aligner_attn_logprobs,
                    in_lens=context_tensors['text_lens'],
                    out_lens=audio_codes_lens_input,
                )
            else:
                with torch.no_grad():
                    # Just get the attention matrix without computing the loss or gradients
                    aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder(
                        queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1),  # B, E, T'
                        keys=context_tensors['text_encoder_out'].permute(0, 2, 1),  # B, E, T
                        mask=~context_tensors['text_mask'].unsqueeze(-1),
                        attn_prior=aligner_prior,
                    )

            with torch.no_grad():
                aligner_attn_hard = self.get_binarized_prior_matrix(
                    aligner_attn_soft, audio_codes_lens_input, context_tensors['text_lens']
                )
                if (self.global_step > self.binarize_prior_after_step) and context_tensors['prior_used']:
                    attn_prior = self.replace_beta_binomial_prior_with_binarized(attn_prior, aligner_attn_hard)

        logits, attn_info, dec_out = self.forward(
            dec_input_embedded=dec_input_embedded,
            dec_input_mask=dec_input_mask,
            cond=cond,
            cond_mask=cond_mask,
            attn_prior=attn_prior,
            multi_encoder_mapping=context_tensors['multi_encoder_mapping'],
        )
        # logits: (B, T', num_codebooks * num_tokens_per_codebook)
        # dec_out: (B, T', E)
        dec_context_size = context_tensors['dec_context_size']
        logits = logits[:, dec_context_size:, :]  # Remove the context audio embeddings from the logits

        # Codebook loss (parallel)
        codebook_loss, loss_mask = self.compute_loss(
            logits,
            audio_codes_target_unstacked,
            audio_codes_lens_target_unstacked,
            frame_stacking_factor=self.frame_stacking_factor,
        )
        # Alignment loss
        alignment_loss = None
        if self.alignment_loss_scale > 0.0 and not disable_alignment_loss:
            text_lens = context_tensors['text_lens']
            cross_attention_scores = [
                attn['cross_attn_probabilities'][1]
                for layer_idx, attn in enumerate(attn_info)
                if layer_idx in self.ctc_prior_layer_ids
            ]
            alignment_loss = self.compute_alignment_loss(
                cross_attention_scores, text_lens, audio_codes_lens_input, dec_context_size
            )
            loss = self.codebook_loss_scale * codebook_loss + alignment_loss
        else:
            loss = self.codebook_loss_scale * codebook_loss

        # Local Transformer loss
        local_transformer_loss = None
        local_transformer_logits = None
        if self.local_transformer_type != LocalTransformerType.NO_LT:
            if self.local_transformer_type == LocalTransformerType.MASKGIT:
                # Maskgit
                # randomly replace some positions with MASK_TOKEN
                audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target_unstacked)
                # TODO @rfejgin: the very last position might be padding but the local transformer might look at it as part
                #                of a pair where the first position is valid. Is this an issue?
                local_transformer_logits = self.compute_local_transformer_logits(
                    dec_out[:, dec_context_size:, :], audio_codes_masked, targets_offset_by_one=True
                )
                local_transformer_loss, _ = self.compute_loss(
                    local_transformer_logits,
                    audio_codes_target_unstacked,
                    audio_codes_lens_target_unstacked,
                    mask_tokens_mask,
                    frame_stacking_factor=self.frame_stacking_factor,
                )
            else:
                # Autoregressive
                assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type"
                local_transformer_logits = self.compute_local_transformer_logits(
                    dec_out[:, dec_context_size:, :], audio_codes_target_unstacked, targets_offset_by_one=False
                )
                local_transformer_loss, _ = self.compute_loss(
                    local_transformer_logits,
                    audio_codes_target_unstacked,
                    audio_codes_lens_target_unstacked,
                    None,
                    frame_stacking_factor=self.frame_stacking_factor,
                )
            loss = loss + self.local_transformer_loss_scale * local_transformer_loss

        if aligner_encoder_loss is not None:
            loss = loss + aligner_encoder_loss

        return {
            'logits': logits,
            'attn_info': attn_info,
            'loss': loss,
            'codebook_loss': codebook_loss,
            'local_transformer_loss': local_transformer_loss,
            'local_transformer_logits': local_transformer_logits,
            'loss_mask': loss_mask,
            'alignment_loss': alignment_loss,
            'aligner_encoder_loss': aligner_encoder_loss,
            'audio_codes_target': audio_codes_target_unstacked,
            'audio_codes_lens_target': audio_codes_lens_target_unstacked,
            'text': context_tensors['text'],
            'text_lens': context_tensors['text_lens'],
            'context_audio_codes': context_tensors['context_audio_codes'],
            'context_audio_codes_lens': context_tensors['context_audio_codes_lens'],
            'dec_context_size': dec_context_size,
            'aligner_attn_soft': aligner_attn_soft,
            'aligner_attn_hard': aligner_attn_hard,
        }

    def training_step(self, batch, batch_idx):
        batch_output = self.process_batch(batch)
        loss = batch_output['loss']
        codebook_loss = batch_output['codebook_loss']
        self.log('train/codebook_loss', codebook_loss, prog_bar=True, sync_dist=True)
        if self.cfg_unconditional_prob == 0.0:
            # Only log alignment loss when not using cfg to avoid sync issues when
            # alignment loss is None on some ranks
            alignment_loss = batch_output['alignment_loss']
            if alignment_loss is not None:
                self.log('train/alignment_loss', alignment_loss, prog_bar=True, sync_dist=True)
        self.log('train/loss', loss, prog_bar=True, sync_dist=True)
        local_transformer_loss = batch_output['local_transformer_loss']
        if local_transformer_loss is not None:
            self.log('train/local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True)

        # Log batch info
        batch_size, text_token_max_len = batch["text"].shape
        text_token_total_num = batch["text_lens"].sum()
        batch_info_dict = {
            "train/batch_size": batch_size,
            "train/text_token_max_len": text_token_max_len,
            "train/text_token_total_num_in_batch": text_token_total_num.item(),
            "train/text_token_pad_ratio_percent_in_batch": 100
            * (1 - text_token_total_num / (batch_size * text_token_max_len)),
        }

        if "audio_codes" in batch:
            audio_codes_max_len = batch["audio_codes"].shape[-1]
            audio_codes_total_num = batch["audio_codes_lens"].sum()
            batch_info_dict.update(
                {
                    "train/audio_codes_max_len": audio_codes_max_len,
                    "train/audio_codes_total_num_in_batch": audio_codes_total_num.item(),
                    "train/audio_codes_pad_ratio_percent_in_batch": 100
                    * (1 - audio_codes_total_num / (batch_size * audio_codes_max_len)),
                }
            )
        else:
            audio_samples_max_len = batch["audio"].shape[-1]
            audio_samples_total_num = batch["audio_lens"].sum()
            batch_info_dict.update(
                {
                    "train/audio_samples_max_len": audio_samples_max_len,
                    "train/audio_samples_total_num_in_batch": audio_samples_total_num.item(),
                    "train/audio_samples_pad_ratio_percent_in_batch": 100
                    * (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)),
                }
            )

        self.log_dict(batch_info_dict, on_step=True)

        return loss

    def validation_step(self, batch, batch_idx):
        batch_output = self.process_batch(batch, mode="val")
        # self.process_batch returns a dict. We currently only log "logits", which come from the parallel
        # prediction head. If a local transformer is used, the dict also contains "local_transformer_logits".
        loss = batch_output['loss']
        codebook_loss = batch_output['codebook_loss']
        alignment_loss = batch_output['alignment_loss']
        aligner_encoder_loss = batch_output['aligner_encoder_loss']
        logits = batch_output['logits']
        audio_codes_target = batch_output['audio_codes_target']
        audio_codes_lens_target = batch_output['audio_codes_lens_target']
        context_audio_codes = batch_output['context_audio_codes']
        context_audio_codes_lens = batch_output['context_audio_codes_lens']
        attn_info = batch_output['attn_info']
        text_lens = batch_output['text_lens']
        dec_context_size = batch_output['dec_context_size']
        if alignment_loss is None:
            alignment_loss = torch.tensor(0.0, device=loss.device)
        if aligner_encoder_loss is None:
            aligner_encoder_loss = torch.tensor(0.0, device=loss.device)

        if batch_idx == 0 and self.global_rank == 0:
            # Prepare dictionary for aggregated wandb logging
            wandb_log_dict = {}

            # Get audio data for logging
            wandb_log_dict.update(
                self.log_val_audio_example(
                    logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens
                )
            )

            # Get attention image data for logging
            if len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1:
                # cross_attn_probabilities only returned when not using flash attention
                cross_attention_probs = [
                    attn['cross_attn_probabilities'][0]
                    for layer_idx, attn in enumerate(attn_info)
                    if layer_idx in self.ctc_prior_layer_ids
                ]
                wandb_log_dict.update(
                    self.log_attention_probs(
                        cross_attention_probs,
                        audio_codes_lens_target,
                        text_lens,
                        prefix="val",
                        dec_context_size=dec_context_size,
                    )
                )

                for layer_idx in self.transcript_decoder_layers:
                    cross_attention_probs = [attn_info[layer_idx]['cross_attn_probabilities'][0]]
                    wandb_log_dict.update(
                        self.log_attention_probs(
                            cross_attention_probs,
                            audio_codes_lens_target,
                            text_lens,
                            prefix=f"val/layer_{layer_idx}",
                            dec_context_size=dec_context_size,
                        )
                    )

                if batch_output['aligner_attn_soft'] is not None:
                    wandb_log_dict.update(
                        self.log_attention_probs(
                            [batch_output['aligner_attn_soft']],
                            audio_codes_lens_target,
                            text_lens,
                            prefix="val/aligner_encoder_attn",
                        )
                    )

                if batch_output['aligner_attn_hard'] is not None:
                    wandb_log_dict.update(
                        self.log_attention_probs(
                            [batch_output['aligner_attn_hard'].unsqueeze(1)],
                            audio_codes_lens_target,
                            text_lens,
                            prefix="val/aligner_encoder_attn_hard",
                        )
                    )

            # Perform single wandb log call if wandb is active and there is data
            for logger in self.loggers:
                if isinstance(logger, WandbLogger) and wandb_log_dict:
                    logger.experiment.log(wandb_log_dict)

        local_transformer_loss = batch_output['local_transformer_loss']
        val_output = {
            'val_loss': loss,
            'val_codebook_loss': codebook_loss,
            'val_alignment_loss': alignment_loss,
            'val_local_transformer_loss': local_transformer_loss,
            'val_aligner_encoder_loss': aligner_encoder_loss,
        }
        self.validation_step_outputs.append(val_output)

        return val_output

    def get_cross_attention_scores(self, attn_probs, filter_layers=None):
        """
        Returns the cross attention probabilities for the last audio timestep
        """
        mean_cross_attn_scores = []
        all_heads_cross_attn_scores = []
        for lidx, layerwise_attn_prob in enumerate(attn_probs):
            if (filter_layers is not None and lidx not in filter_layers) or (
                lidx not in self.transcript_decoder_layers
            ):
                continue
            cross_attn_prob = layerwise_attn_prob['cross_attn_probabilities'][
                0
            ]  # B, H, audio_timesteps, text_timesteps
            mean_cross_attn_scores.append(cross_attn_prob.mean(dim=1))  # B, audio_timesteps, text_timesteps
            for head_idx in range(cross_attn_prob.size(1)):
                all_heads_cross_attn_scores.append(cross_attn_prob[:, head_idx, -1, :])  # B, text_timesteps

        mean_cross_attn_scores = torch.stack(mean_cross_attn_scores, dim=1)  # B, L, audio_timesteps, text_timesteps
        mean_cross_attn_scores = mean_cross_attn_scores.mean(dim=1)  # B, audio_timesteps, text_timesteps
        last_audio_timestep_scores = mean_cross_attn_scores[:, -1, :]  # B, text_timesteps
        return last_audio_timestep_scores, all_heads_cross_attn_scores

    def get_most_attended_text_timestep(
        self,
        alignment_attention_scores,
        last_attended_timesteps,
        text_lens,
        lookahead_window_size,
        attended_timestep_counter,
        batch_size,
    ):
        """
        Returns the most attended timestep for each batch item
        """
        text_time_step_attended = []
        for bidx in range(batch_size):
            last_attended_timestep = last_attended_timesteps[-1][bidx]
            if attended_timestep_counter[bidx].get(last_attended_timestep, 0) >= 8:
                # This is probably an attention sink! Move to the next timestep
                last_attended_timestep += 1
            window_size = lookahead_window_size
            window_end = min(last_attended_timestep + window_size, text_lens[bidx] - 3)  # Ignore the last 3 timesteps
            item_attention_scores = alignment_attention_scores[bidx, last_attended_timestep:window_end]
            if item_attention_scores.size(0) == 0:
                # This means the sentence has ended
                attended_timestep = text_lens[bidx].item() - 1
            else:
                attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep
            text_time_step_attended.append(attended_timestep)
            attended_timestep_counter[bidx][attended_timestep] = (
                attended_timestep_counter[bidx].get(attended_timestep, 0) + 1
            )
        return text_time_step_attended, attended_timestep_counter

    def construct_inference_prior(
        self,
        prior_epsilon,
        cross_attention_scores,
        text_lens,
        text_time_step_attended,
        attended_timestep_counter,
        unfinished_texts,
        finished_texts_counter,
        end_indices,
        lookahead_window_size,
        batch_size,
    ):
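        """Build the attention prior for the next decoding step during inference.

        For each batch item the prior is near-zero (`prior_epsilon`) everywhere except a small window
        around the currently attended text timestep, with extra handling for very short texts,
        attention sinks, and items that are close to (or past) the end of their text.
        """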
        # Attn prior for the next timestep
        _attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + prior_epsilon
        _attn_prior = _attn_prior.to(cross_attention_scores.device)
        for bidx in range(cross_attention_scores.shape[0]):
            if bidx < batch_size:
                _text_len = text_lens[bidx]
                if text_lens[bidx] <= 5:
                    # Very short sentences, No Prior
                    _attn_prior[bidx, 0, :] = 1.0
                else:
                    _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx] - 1)] = (
                        1.0  # Slight exposure to history for better pronunciation. Not very important.
                    )
                    _attn_prior[bidx, 0, text_time_step_attended[bidx]] = (
                        1.0  # Slightly bias to continue moving forward. Not very important.
                    )
                    for ind in range(1, lookahead_window_size + 1):
                        _attn_prior[bidx, 0, min(text_time_step_attended[bidx] + ind, _text_len - 1)] = 1.0

                # Penalize timesteps that have been attended to at least 10 times
                for _timestep in attended_timestep_counter[bidx]:
                    if attended_timestep_counter[bidx][_timestep] >= 10:
                        # This timestep has been attended to at least 10 times; penalize it and everything before it to avoid getting stuck
                        _attn_prior[bidx, 0, : _timestep + 1] = prior_epsilon

                unfinished_texts[bidx] = False
                if text_time_step_attended[bidx] < text_lens[bidx] - 3:
                    # This means the sentence has not ended
                    if bidx not in end_indices:
                        unfinished_texts[bidx] = True

                if text_time_step_attended[bidx] >= text_lens[bidx] - 2 or bidx in end_indices:
                    if bidx not in finished_texts_counter:
                        finished_texts_counter[bidx] = 0

        for bidx in finished_texts_counter:
            finished_texts_counter[bidx] += 1
            if finished_texts_counter[bidx] > 5:
                # This means we have been within the text EOS window for more than 5 timesteps.
                # We should allow EOS to be predicted now.
                unfinished_texts[bidx] = False

        return _attn_prior, unfinished_texts, finished_texts_counter

    def get_inference_attention_plots(
        self,
        cross_attention_scores_all_timesteps,
        all_heads_cross_attn_scores_all_timesteps,
        text_lens,
        predicted_codes_lens,
        batch_size,
        compute_all_heads_attn_maps,
        last_attended_timestep,
    ):
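        """Assemble per-item cross-attention alignment plots (and, optionally, per-head plots) accumulated
        over all decoding timesteps, returned as numpy images via `plot_alignment_to_numpy`.
        """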
        last_attended_timestep = np.array(last_attended_timestep).T
        cross_attention_scores_all_timesteps = torch.stack(
            cross_attention_scores_all_timesteps, dim=2
        )  # B, text_timesteps, T'
        headwise_cross_attention_scores_all_timesteps = []
        for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])):
            head_cross_attention_all_timesteps = torch.stack(
                [x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2
            )  # B, text_timesteps, T'
            headwise_cross_attention_scores_all_timesteps.append(head_cross_attention_all_timesteps)

        cross_attention_maps = []
        headwise_cross_attention_maps = []
        for bidx in range(batch_size):
            item_cross_attention_scores = cross_attention_scores_all_timesteps[
                bidx, : text_lens[bidx], : predicted_codes_lens[bidx]
            ]
            cross_attn_np = plot_alignment_to_numpy(
                item_cross_attention_scores.cpu().numpy(),
                attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]],
            )
            cross_attention_maps.append(cross_attn_np)
            item_all_head_cross_attn_maps = []
            if compute_all_heads_attn_maps:
                for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])):
                    item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][
                        bidx, : text_lens[bidx], : predicted_codes_lens[bidx]
                    ]
                    headwise_cross_attn_np = plot_alignment_to_numpy(
                        item_headwise_cross_attention_scores.cpu().numpy(),
                        attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]],
                    )
                    item_all_head_cross_attn_maps.append(headwise_cross_attn_np)
                headwise_cross_attention_maps.append(item_all_head_cross_attn_maps)

        return cross_attention_maps, headwise_cross_attention_maps

    def find_eos_frame_index(self, codes, eos_detection_method) -> Union[int, float]:
        """
        Checks for EOS in the predicted codes. Returns the index of the first frame within the frame stack
        that triggers EOS under the given detection method, or `float('inf')` if no EOS is found.
        Args:
            codes: (num_codebooks, frame_stacking_factor)
        Returns:
            index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found
        """
        eos_mask = codes == self.audio_eos_id  # (codebooks, frame_stacking_factor)
        detection_type = EOSDetectionMethod.detection_type(eos_detection_method)
        if detection_type == "any":
            eos_per_frame = eos_mask.any(
                dim=0
            )  # (frame_stacking_factor,) - True if any codebook has EOS in this frame
        elif detection_type == "all":
            eos_per_frame = eos_mask.all(
                dim=0
            )  # (frame_stacking_factor,) - True if all codebooks have EOS in this frame
        elif detection_type == "zero_cb":
            eos_per_frame = eos_mask[:1, :].any(
                dim=0
            )  # (frame_stacking_factor,) - True if zeroth codebook has EOS in this frame
        else:
            raise ValueError(f"Invalid EOS detection method: {eos_detection_method}")
        # find first frame with EOS
        if eos_per_frame.any():
            # return index of the first frame with EOS
            return eos_per_frame.nonzero()[0].item()
        return float('inf')

    def detect_eos(self, audio_codes_multinomial, audio_codes_argmax, eos_detection_method) -> Union[int, float]:
        """
        Detects EOS in the predicted codes. Returns the index of the first frame within the frame stack
        that triggers EOS detection, or `float('inf')` if no EOS is found.
        Args:
            audio_codes_multinomial: (num_codebooks, frame_stacking_factor) - Multinomial samples
            audio_codes_argmax: (num_codebooks, frame_stacking_factor) - Argmax samples
            eos_detection_method: EOS detection method
        Returns:
            index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found
        """
        sampling_type = EOSDetectionMethod.sampling_type(eos_detection_method)
        if sampling_type == "argmax":
            return self.find_eos_frame_index(audio_codes_argmax, eos_detection_method)
        elif sampling_type == "argmax_or_multinomial":
            argmax_eos_frame = self.find_eos_frame_index(audio_codes_argmax, eos_detection_method)
            multinomial_eos_frame = self.find_eos_frame_index(audio_codes_multinomial, eos_detection_method)
            return min(argmax_eos_frame, multinomial_eos_frame)
        else:
            raise ValueError(f"Invalid EOS detection method: {eos_detection_method}")

    def infer_batch(
        self,
        batch,
        max_decoder_steps=500,
        temperature=0.7,
        topk=80,
        use_cfg=False,
        cfg_scale=1.0,
        return_cross_attn_probs=False,
        apply_attention_prior=False,
        prior_epsilon=1e-5,
        lookahead_window_size=10,
        estimate_alignment_from_layers=None,
        apply_prior_to_layers=None,
        start_prior_after_n_audio_steps=10,
        compute_all_heads_attn_maps=False,
        use_local_transformer_for_inference=False,
        use_LT_kv_cache=True,
        maskgit_n_steps=3,
        maskgit_noise_scale=0.0,
        maskgit_fixed_schedule=None,
        maskgit_dynamic_cfg_scale=False,
        maskgit_sampling_type=None,
        ignore_finished_sentence_tracking=False,
        eos_detection_method="argmax_or_multinomial_any",
        # Setting this greater than 0 prevents rare cases of first-frame termination. Any value between 1 and 4
        # should work, but 4 lines up with the codec's minimum frame requirement.
        min_generated_frames=4,
    ):
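        """Autoregressively generate audio codes (and decoded audio) for a batch of text/context inputs.

        Supports classifier-free guidance, an optional attention prior constructed on the fly to keep the
        alignment moving forward, local-transformer or parallel codebook sampling, and per-item EOS detection.
        Returns an `InferBatchOutput` with audio, codes, lengths, RTF metrics, and optional attention maps.
        """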
        eos_detection_method = EOSDetectionMethod(eos_detection_method)
        with torch.no_grad():
            start_time = time.time()
            self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference)

            context_tensors = self.prepare_context_tensors(batch)
            text = context_tensors['text']
            audio_codes_bos = torch.full(
                (text.size(0), self.num_audio_codebooks, self.frame_stacking_factor),
                self.audio_bos_id,
                device=text.device,
            ).long()
            audio_codes_lens = torch.full(
                (text.size(0),), 1, device=text.device
            ).long()  # intentionally 1 rather than self.frame_stacking_factor since this is in stacked form
            audio_codes_input = audio_codes_bos
            audio_codes_mask = get_mask_from_lengths(audio_codes_lens)

            all_predictions = []
            end_indices = {}

            if use_cfg:
                dummy_cond, dummy_cond_mask, dummy_additional_decoder_input, dummy_additional_dec_mask, _ = (
                    self.prepare_dummy_cond_for_cfg(
                        context_tensors['cond'],
                        context_tensors['cond_mask'],
                        context_tensors['additional_decoder_input'],
                        context_tensors['additional_decoder_mask'],
                    )
                )

            cross_attention_scores_all_timesteps = []
            all_heads_cross_attn_scores_all_timesteps = []
            _attn_prior = None
            unfinished_texts = {}
            finished_texts_counter = {}
            attended_timestep_counter = [{} for _ in range(text.size(0))]
            last_attended_timesteps = [
                [1 for _ in range(text.size(0))]
            ]  # Maintain a list of attended timesteps as we predict audio for each batch item
            time_to_first_prediction = 0.0
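            # Main autoregressive loop: each iteration predicts one (frame-stacked) group of codec frames.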
            for idx in range(max_decoder_steps // self.frame_stacking_factor):
                if idx == 1:
                    time_to_first_prediction = time.time() - start_time
                if idx % 20 == 0:
                    print(f"Decoding timestep {idx}")
                audio_codes_embedded = self.embed_audio_tokens(audio_codes_input)
                if context_tensors['additional_decoder_input'] is not None:
                    _audio_codes_embedded = torch.cat(
                        [context_tensors['additional_decoder_input'], audio_codes_embedded], dim=1
                    )
                    _audio_codes_mask = torch.cat(
                        [context_tensors['additional_decoder_mask'], audio_codes_mask], dim=1
                    )
                else:
                    _audio_codes_embedded = audio_codes_embedded
                    _audio_codes_mask = audio_codes_mask

                if apply_prior_to_layers is not None:
                    attn_prior = [None for _ in range(self.decoder.n_layers)]
                    for layer_idx in apply_prior_to_layers:
                        attn_prior[layer_idx] = _attn_prior
                else:
                    attn_prior = _attn_prior

                if self.model_type == 'multi_encoder_context_tts':
                    attn_prior = [attn_prior, None]

                if use_cfg:
                    batch_size = audio_codes_embedded.size(0)
                    if isinstance(context_tensors['cond'], list):
                        cfg_cond = [
                            torch.cat([cond_item, dummy_cond_item], dim=0)
                            for cond_item, dummy_cond_item in zip(context_tensors['cond'], dummy_cond)
                        ]
                        cfg_cond_mask = [
                            torch.cat([cond_mask_item, dummy_cond_mask_item], dim=0)
                            for cond_mask_item, dummy_cond_mask_item in zip(
                                context_tensors['cond_mask'], dummy_cond_mask
                            )
                        ]
                    else:
                        cfg_cond = torch.cat([context_tensors['cond'], dummy_cond], dim=0)
                        cfg_cond_mask = torch.cat([context_tensors['cond_mask'], dummy_cond_mask], dim=0)
                    cfg_audio_codes_embedded = torch.cat([_audio_codes_embedded, _audio_codes_embedded], dim=0)
                    cfg_audio_codes_mask = torch.cat([_audio_codes_mask, _audio_codes_mask], dim=0)
                    if dummy_additional_decoder_input is not None:
                        cfg_audio_codes_embedded[batch_size:, : dummy_additional_decoder_input.size(1)] = (
                            dummy_additional_decoder_input
                        )
                        cfg_audio_codes_mask[batch_size:, : dummy_additional_decoder_input.size(1)] = (
                            dummy_additional_dec_mask
                        )

                    # print(f"step {idx}")
                    # print(f"use_cfg {use_cfg}")
                    # print(f"shape {cfg_audio_codes_embedded.shape}")
                    # print(f"use kv cahce? {self.use_kv_cache_for_inference}")
                    combined_logits, attn_probs, dec_out = self.forward(
                        dec_input_embedded=cfg_audio_codes_embedded,
                        dec_input_mask=cfg_audio_codes_mask,
                        cond=cfg_cond,
                        cond_mask=cfg_cond_mask,
                        attn_prior=attn_prior,
                        multi_encoder_mapping=context_tensors['multi_encoder_mapping'],
                    )

                    cond_logits = combined_logits[:batch_size]
                    uncond_logits = combined_logits[batch_size:]
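                    # Classifier-free guidance: mix conditional and unconditional logits;
                    # cfg_scale > 1 extrapolates away from the unconditional prediction.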
                    all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits
                else:
                    batch_size = audio_codes_embedded.size(0)
                    all_code_logits, attn_probs, dec_out = self.forward(
                        dec_input_embedded=_audio_codes_embedded,
                        dec_input_mask=_audio_codes_mask,
                        cond=context_tensors['cond'],
                        cond_mask=context_tensors['cond_mask'],
                        attn_prior=attn_prior,
                        multi_encoder_mapping=context_tensors['multi_encoder_mapping'],
                    )

                if return_cross_attn_probs or apply_attention_prior:
                    cross_attention_scores, all_heads_cross_attn_scores = self.get_cross_attention_scores(
                        attn_probs
                    )  # B, text_timesteps
                    alignment_attention_scores = cross_attention_scores
                    if estimate_alignment_from_layers is not None:
                        alignment_attention_scores, _ = self.get_cross_attention_scores(
                            attn_probs, filter_layers=estimate_alignment_from_layers
                        )  # B, text_timesteps

                    cross_attention_scores_all_timesteps.append(cross_attention_scores)
                    all_heads_cross_attn_scores_all_timesteps.append(all_heads_cross_attn_scores)

                if apply_attention_prior and idx >= start_prior_after_n_audio_steps:
                    text_time_step_attended, attended_timestep_counter = self.get_most_attended_text_timestep(
                        alignment_attention_scores=alignment_attention_scores,
                        last_attended_timesteps=last_attended_timesteps,
                        text_lens=context_tensors['text_lens'],
                        lookahead_window_size=lookahead_window_size,
                        attended_timestep_counter=attended_timestep_counter,
                        batch_size=batch_size,
                    )
                    last_attended_timesteps.append(text_time_step_attended)
                    _attn_prior, unfinished_texts, finished_texts_counter = self.construct_inference_prior(
                        prior_epsilon=prior_epsilon,
                        cross_attention_scores=cross_attention_scores,
                        text_lens=context_tensors['text_lens'],
                        text_time_step_attended=text_time_step_attended,
                        attended_timestep_counter=attended_timestep_counter,
                        unfinished_texts=unfinished_texts,
                        finished_texts_counter=finished_texts_counter,
                        end_indices=end_indices,
                        lookahead_window_size=lookahead_window_size,
                        batch_size=batch_size,
                    )

                if ignore_finished_sentence_tracking:
                    finished_items = {}
                    unfinished_items = {}
                else:
                    finished_items = {
                        k: v for k, v in finished_texts_counter.items() if v >= 20
                    }  # Items that have been close to the end for at least 20 timesteps
                    unfinished_items = {k: v for k, v in unfinished_texts.items() if v}

                # Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor)
                # This guards against rare cases of termination right at the start of generation.
                forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames

                all_code_logits_t = all_code_logits[:, -1, :]  # (B, num_codebooks * num_tokens_per_codebook)
                if use_local_transformer_for_inference:
                    if self.local_transformer_type == LocalTransformerType.AR:
                        # Autoregressive sampling with local transformer
                        audio_codes_next = self.local_transformer_sample_autoregressive(
                            dec_output=dec_out[:, -1, :],
                            temperature=temperature,
                            topk=topk,
                            unfinished_items=unfinished_items,
                            finished_items=finished_items,
                            use_cfg=use_cfg,
                            cfg_scale=cfg_scale,
                            use_kv_cache=use_LT_kv_cache,
                            forbid_audio_eos=forbid_audio_eos,
                        )
                    elif self.local_transformer_type == LocalTransformerType.MASKGIT:
                        audio_codes_next = self.local_transformer_sample_maskgit(
                            dec_output=dec_out[:, -1, :],
                            temperature=temperature,
                            topk=topk,
                            unfinished_items=unfinished_items,
                            finished_items=finished_items,
                            use_cfg=use_cfg,
                            cfg_scale=cfg_scale,
                            n_steps=maskgit_n_steps,
                            noise_scale=maskgit_noise_scale,
                            fixed_schedule=maskgit_fixed_schedule,
                            dynamic_cfg_scale=maskgit_dynamic_cfg_scale,
                            sampling_type=maskgit_sampling_type,
                            forbid_audio_eos=forbid_audio_eos,
                        )
                    else:
                        raise ValueError(
                            f"Local transformer inference requested by but local transformer type is {self.local_transformer_type}"
                        )
                else:
                    # Parallel sampling from all codebooks
                    audio_codes_next = self.sample_codes_from_logits(
                        all_code_logits_t,
                        temperature=temperature,
                        topk=topk,
                        unfinished_items=unfinished_items,
                        finished_items=finished_items,
                        forbid_audio_eos=forbid_audio_eos,
                    )  # (B, num_codebooks, frame_stacking_factor)
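                # Also draw a near-greedy sample (temperature 0.01, top-1); it is used only for EOS detection below.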
                all_codes_next_argmax = self.sample_codes_from_logits(
                    all_code_logits_t,
                    temperature=0.01,
                    topk=1,
                    unfinished_items=unfinished_items,
                    finished_items=finished_items,
                    forbid_audio_eos=forbid_audio_eos,
                )  # (B, num_codebooks, frame_stacking_factor)

                for item_idx in range(all_codes_next_argmax.size(0)):
                    if item_idx not in end_indices:
                        end_frame_index = self.detect_eos(
                            audio_codes_next[item_idx], all_codes_next_argmax[item_idx], eos_detection_method
                        )
                        if end_frame_index != float('inf'):
                            global_index = idx * self.frame_stacking_factor + end_frame_index
                            end_indices[item_idx] = global_index
                            print(f"End detected for item {item_idx} at decoder timestep: {idx}")

                all_predictions.append(audio_codes_next)
                audio_codes_input = torch.cat([audio_codes_input, audio_codes_next], dim=-1)  # (B, C, T')
                audio_codes_lens = audio_codes_lens + 1  # already in stacked form
                audio_codes_mask = get_mask_from_lengths(audio_codes_lens)
                if len(end_indices) == text.size(0) and len(all_predictions) >= 4:
                    # The codec needs at least 4 frames to be decoded properly
                    print("All ends reached")
                    break
            tts_generation_time = time.time() - start_time
            tts_generation_time_per_frame = tts_generation_time / (len(all_predictions) * self.frame_stacking_factor)

            # Concatenate the list of predictions along the time dimension. Note that when frame stacking is on,
            # this also undoes the stacking.
            predicted_codes = torch.cat(all_predictions, dim=-1)  # (B, num_codebooks, T')
            predicted_lens = [
                end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0))
            ]  # Default to max_decoder_steps for items where no EOS was detected
            predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long()

            predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens)
            end_time = time.time()
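            # Approximate total duration: treats every item in the batch as being as long as the longest prediction.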
            total_audio_duration_generated = (
                predicted_audio_lens.max().item() * predicted_audio_lens.shape[0]
            ) / self.sample_rate
            rtf = total_audio_duration_generated / (end_time - start_time)
            rtf_metrics = {
                'rtf': rtf,
                'time_to_first_prediction': time_to_first_prediction,
                'tts_generation_time': tts_generation_time,
                'max_frames_generated': len(all_predictions),
                'tts_generation_time_per_frame': tts_generation_time_per_frame,
                'batch_size': text.size(0),
            }
            torch.cuda.empty_cache()
            cross_attention_maps = None
            headwise_cross_attention_maps = None
            if return_cross_attn_probs:
                cross_attention_maps, headwise_cross_attention_maps = self.get_inference_attention_plots(
                    cross_attention_scores_all_timesteps,
                    all_heads_cross_attn_scores_all_timesteps,
                    context_tensors['text_lens'],
                    predicted_codes_lens,
                    text.size(0),
                    compute_all_heads_attn_maps,
                    last_attended_timesteps,
                )

            return InferBatchOutput(
                predicted_audio=predicted_audio,
                predicted_audio_lens=predicted_audio_lens,
                predicted_codes=predicted_codes,
                predicted_codes_lens=predicted_codes_lens,
                rtf_metrics=rtf_metrics,
                cross_attention_maps=cross_attention_maps,
                headwise_cross_attention_maps=headwise_cross_attention_maps,
            )

    def test_step(self, batch, batch_idx):
        with torch.no_grad():
            test_dl_batch_size = self._test_dl.batch_size
            temperature = self.cfg.get('inference_temperature', 0.7)
            topk = self.cfg.get('inference_topk', 80)
            use_cfg = self.cfg.get('inference_use_cfg', False)
            cfg_scale = self.cfg.get('inference_cfg_scale', 1.0)
            output = self.infer_batch(
                batch,
                max_decoder_steps=self.cfg.get('max_decoder_steps', 500),
                temperature=temperature,
                topk=topk,
                use_cfg=use_cfg,
                cfg_scale=cfg_scale,
            )
            predicted_audio = output.predicted_audio
            predicted_audio_lens = output.predicted_audio_lens

            for logger in self.loggers:
                is_wandb = isinstance(logger, WandbLogger)
                is_tb = isinstance(logger, TensorBoardLogger)
                if not is_wandb and not is_tb:
                    raise ValueError(
                        f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
                    )

                for idx in range(predicted_audio.size(0)):
                    predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
                    predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]]
                    item_idx = batch_idx * test_dl_batch_size + idx

                    if is_wandb:
                        log_dict = {
                            "test/predicted_audio": wandb.Audio(
                                predicted_audio_np, sample_rate=self.sample_rate, caption="Predicted Audio"
                            ),
                        }
                        logger.experiment.log(log_dict, step=item_idx)

                    if is_tb:
                        logger.experiment.add_audio(
                            'test/predicted_audio',
                            predicted_audio_np,
                            global_step=item_idx,
                            sample_rate=self.sample_rate,
                        )

                    # Save the predicted audio
                    log_dir = logger.log_dir
                    audio_dir = os.path.join(log_dir, 'audios')
                    if not os.path.exists(audio_dir):
                        os.makedirs(audio_dir)
                    audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav')
                    sf.write(audio_path, predicted_audio_np, self.sample_rate)

    def on_validation_epoch_end(self):
        collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean()
        val_loss = collect("val_loss")
        val_codebook_loss = collect("val_codebook_loss")
        val_alignment_loss = collect("val_alignment_loss")
        val_aligner_encoder_loss = collect("val_aligner_encoder_loss")
        # log val_loss in the same group as the other val metrics.
        self.log("val/loss", val_loss, prog_bar=True, sync_dist=True)
        # ensure val_loss is available for epoch-level checkpointing and filename generation without cluttering wandb logs.
        self.log(
            "val_loss",
            val_loss,
            prog_bar=False,
            sync_dist=True,
            on_step=False,
            on_epoch=True,
            logger=False,
            enable_graph=False,
        )
        self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True)
        self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True)
        self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True)
        if self.local_transformer_type != LocalTransformerType.NO_LT:
            val_local_transformer_loss = collect("val_local_transformer_loss")
            self.log("val/local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True)
        self.validation_step_outputs.clear()  # free memory

    def get_dataset(self, dataset_cfg, dataset_type):
        dataset = instantiate(
            dataset_cfg.dataset,
            sample_rate=self.sample_rate,
            bos_id=self.bos_id,
            eos_id=self.eos_id,
            audio_bos_id=self.audio_bos_id,
            audio_eos_id=self.audio_eos_id,
            context_audio_bos_id=self.context_audio_bos_id,
            context_audio_eos_id=self.context_audio_eos_id,
            num_audio_codebooks=self.data_num_audio_codebooks,
            codec_model_samples_per_frame=self.codec_model_samples_per_frame,
            prior_scaling_factor=self.cfg.prior_scaling_factor,
            load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
            dataset_type=dataset_type,  # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn)
            use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder,
            text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name,
            pad_context_text_to_max_duration=self.pad_context_text_to_max_duration,
            context_duration_min=self.cfg.context_duration_min,
            context_duration_max=self.cfg.context_duration_max,
            text_context_remapping=self.text_context_remapping,
            text_context_remapping_prob=self.text_context_remapping_prob,
        )
        dataset.load_16khz_audio = False
        dataset.tokenizer_config = (
            self.cfg.text_tokenizers
        )  # This will be used in worker_init_fn for instantiating tokenizer
        return dataset

    def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader:
        # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also
        #   cfg is a classifier-free guidance.
        dataset = MagpieTTSLhotseDataset(
            sample_rate=self.sample_rate,
            volume_norm=dataset_cfg.volume_norm,
            codec_model_samples_per_frame=self.codec_model_samples_per_frame,
            audio_bos_id=self.audio_bos_id,
            audio_eos_id=self.audio_eos_id,
            context_audio_bos_id=self.context_audio_bos_id,
            context_audio_eos_id=self.context_audio_eos_id,
            num_audio_codebooks=self.data_num_audio_codebooks,
            prior_scaling_factor=self.cfg.prior_scaling_factor,
            load_cached_codes_if_available=self.cfg.load_cached_codes_if_available,
            dataset_type=mode,  # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn)
            load_16khz_audio=False,
            pad_context_text_to_max_duration=self.pad_context_text_to_max_duration,
            context_duration_min=self.cfg.context_duration_min,
            context_duration_max=self.cfg.context_duration_max,
            use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder,
            text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name,
            tokenizer_config=self.cfg.text_tokenizers,
            text_context_remapping=self.text_context_remapping,
            text_context_remapping_prob=self.text_context_remapping_prob,
        )
        data_loader = get_lhotse_dataloader_from_config(
            config=dataset_cfg.dataset,
            global_rank=self.global_rank,
            world_size=self.world_size,
            dataset=dataset,
        )
        return data_loader

    def setup_training_data(self, dataset_cfg):
        if dataset_cfg.get("use_lhotse", False):
            # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also
            #   cfg is a classifier-free guidance.

            # specify target sampling rate the same as codec model's because lhotse config defaults 16_000.
            if not isinstance(dataset_cfg, DictConfig):
                dataset_cfg = OmegaConf.create(dataset_cfg)
            OmegaConf.set_struct(dataset_cfg.dataset, False)
            dataset_cfg.dataset.update({"sample_rate": self.sample_rate})
            OmegaConf.set_struct(dataset_cfg.dataset, True)

            self._train_dl = self.get_lhotse_dataloader(dataset_cfg, mode='train')
        else:
            dataset = self.get_dataset(dataset_cfg, dataset_type='train')
            sampler = dataset.get_sampler(dataset_cfg.dataloader_params.batch_size, world_size=self.trainer.world_size)
            persistent_workers = True
            if dataset_cfg.dataloader_params.num_workers == 0:
                persistent_workers = False
                # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable)
                dataset.text_tokenizer = setup_tokenizers(
                    all_tokenizers_config=self.cfg.text_tokenizers,
                    mode='train',
                )
            self._train_dl = torch.utils.data.DataLoader(
                dataset,
                collate_fn=dataset.collate_fn,
                sampler=sampler,
                **dataset_cfg.dataloader_params,
                worker_init_fn=worker_init_fn,
                persistent_workers=persistent_workers,
            )

    def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader:
        if dataset_cfg.get("use_lhotse", False):
            # specify target sampling rate the same as codec model's because lhotse config defaults 16_000.
            if not isinstance(dataset_cfg, DictConfig):
                dataset_cfg = OmegaConf.create(dataset_cfg)
            OmegaConf.set_struct(dataset_cfg.dataset, False)
            dataset_cfg.dataset.update({"sample_rate": self.sample_rate})
            OmegaConf.set_struct(dataset_cfg.dataset, True)
            data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test')
        else:
            dataset = self.get_dataset(dataset_cfg, dataset_type='test')
            persistent_workers = True
            if dataset_cfg.dataloader_params.num_workers == 0:
                persistent_workers = False
                # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable)
                dataset.text_tokenizer = setup_tokenizers(all_tokenizers_config=self.cfg.text_tokenizers, mode='test')

            data_loader = torch.utils.data.DataLoader(
                dataset,
                collate_fn=dataset.collate_fn,
                **dataset_cfg.dataloader_params,
                worker_init_fn=worker_init_fn,
                persistent_workers=persistent_workers,
            )
        return data_loader

    def setup_validation_data(self, dataset_cfg):
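        """Build the validation dataloader via `_setup_test_dataloader`."""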
        self._validation_dl = self._setup_test_dataloader(dataset_cfg)

    def setup_test_data(self, dataset_cfg):
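        """Build the test dataloader via `_setup_test_dataloader`."""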
        self._test_dl = self._setup_test_dataloader(dataset_cfg)

    def setup_dummy_text_context_in_batch(
        self,
        batch: Dict[str, torch.Tensor],
    ) -> None:
        """Set up dummy text context tensors in the batch dictionary in place.

        Args:
            batch: Batch dictionary to which `context_text_tokens`, `context_text_tokens_lens`,
                and `has_text_context` entries are added.
        """
        # No text context provided - set up dummy if model requires text conditioning tensors
        dummy_context_text = "[NO TEXT CONTEXT]"
        dummy_tokens = self.tokenizer.encode(
            text=dummy_context_text, tokenizer_name=self.text_conditioning_tokenizer_name
        )
        batch['context_text_tokens'] = torch.tensor([dummy_tokens], device=self.device, dtype=torch.long)
        batch['context_text_tokens_lens'] = torch.tensor([len(dummy_tokens)], device=self.device, dtype=torch.long)
        batch['has_text_context'] = torch.tensor([False], device=self.device, dtype=torch.bool)


    def setup_dummy_audio_context_in_batch(
        self,
        batch: Dict[str, torch.Tensor],
        context_audio: Optional[torch.Tensor] = None,
        context_audio_lens: Optional[torch.Tensor] = None,
    ) -> None:
        """Set up dummy audio context tensors in the batch dictionary in place.

        Args:
            batch: Batch dictionary to which `context_audio_codes` and
                `context_audio_codes_lens` entries are added.
            context_audio: Unused in this method.
            context_audio_lens: Unused in this method.
        """
        # Model has baked context - create minimal dummy context tensors
        # These will be ignored in prepare_context_tensors when baked embedding is used
        dummy_context_codes = torch.zeros(
            1, self.num_audio_codebooks, 2, device=self.device, dtype=torch.long
        )
        dummy_context_codes[:, :, 0] = self.context_audio_bos_id
        dummy_context_codes[:, :, 1] = self.context_audio_eos_id
        batch['context_audio_codes'] = dummy_context_codes
        batch['context_audio_codes_lens'] = torch.tensor([2], device=self.device, dtype=torch.long)

    def do_tts(
        self,
        transcript: str,
        language: str = "en",
        apply_TN: bool = False,
        temperature: float = 0.7,
        topk: int = 80,
        max_decoder_steps: int = 500,
        use_cfg: bool = True,
        cfg_scale: float = 2.5,
    ) -> tuple:
        """
        Generate speech from raw text transcript.

        This is a convenience method for single-utterance text-to-speech synthesis.
        For batch processing, use `infer_batch` directly. Only the baked context embedding
        is supported for context injection; audio and text conditioning are not.
        Custom voice generation is therefore not supported by this method.

        Args:
            transcript: Raw text to synthesize.
            language: Language code for text normalization and tokenization.
                Supported values depend on model's tokenizer configuration.
                Common: "en" (English), "de" (German), "es" (Spanish), etc.
            apply_TN: Whether to apply text normalization to the transcript.
                If True, uses nemo_text_processing for normalization.
            temperature: Sampling temperature for token generation.
            topk: Top-k sampling parameter.
            max_decoder_steps: Maximum number of decoder steps.
            use_cfg: Whether to use classifier-free guidance.
            cfg_scale: Scale factor for classifier-free guidance.

        Returns:
            Tuple of (audio, audio_len) where:
                audio: Generated audio waveform. Shape: (1, T_audio).
                audio_len: Length of generated audio in samples. Shape: (1,).

        Raises:
            ValueError: If the model does not have a baked context embedding.

        Note:
            If apply_TN=True but nemo_text_processing is not installed, normalization is
            skipped and a warning is logged; no exception is raised.

        Example:
            >>> # If text does not need to be normalized
            >>> audio, audio_len = model.do_tts("Hello, how are you today?")
            >>>
            >>> # If text needs to be normalized
            >>> audio, audio_len = model.do_tts(
            ...     "Hello, how are you today?",
            ...     apply_TN=True,
            ... )
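            >>>
            >>> # Adjust sampling and classifier-free guidance (values here are illustrative)
            >>> audio, audio_len = model.do_tts(
            ...     "Hello, how are you today?",
            ...     temperature=0.6,
            ...     topk=40,
            ...     use_cfg=True,
            ...     cfg_scale=2.0,
            ... )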
        """
        if not self.has_baked_context_embedding:
            raise ValueError(
                "Model does not have a baked context embedding. "
                "Please use a checkpoint with a baked context embedding."
            )
        # Apply text normalization if requested
        normalized_text = transcript
        if apply_TN:
            try:
                from nemo_text_processing.text_normalization.normalize import Normalizer

                normalizer = Normalizer(input_case='cased', lang=language)
                normalized_text = normalizer.normalize(transcript, verbose=False)
                logging.debug(f"Text normalization: '{transcript}' -> '{normalized_text}'")
            except ImportError:
                logging.warning(
                    "nemo_text_processing not installed. Skipping text normalization. "
                    "Install with: pip install nemo_text_processing"
                )

        # Determine the tokenizer name based on language.
        # Try to find a matching tokenizer; fall back to the first available one.
        tokenizer_name = None
        available_tokenizers = list(self.tokenizer.tokenizers.keys())
        logging.debug(f"Available tokenizers: {available_tokenizers}")

        # Common mappings for tokenizer names
        language_tokenizer_map = {
            "en": ["english_phoneme", "english"],
            "de": ["german_phoneme", "german"],
            "es": ["spanish_phoneme", "spanish"],
            "fr": ["french_phoneme", "french"],
            "it": ["italian_phoneme", "italian"],
            "vi": ["vietnamese_phoneme", "vietnamese"],
            "zh": ["mandarin_phoneme", "mandarin", "chinese"],
        }

        # Find matching tokenizer
        if language in language_tokenizer_map:
            for candidate in language_tokenizer_map[language]:
                if candidate in available_tokenizers:
                    tokenizer_name = candidate
                    break

        # Fall back to the first available tokenizer
        if tokenizer_name is None:
            tokenizer_name = available_tokenizers[0]
            logging.info(
                f"No tokenizer found for language '{language}'. "
                f"Using '{tokenizer_name}'. Available: {available_tokenizers}"
            )

        # Tokenize the transcript text
        tokens = self.tokenizer.encode(text=normalized_text, tokenizer_name=tokenizer_name)
        tokens = tokens + [self.eos_id]  # Add EOS token (BOS not used per dataset convention)
        text_tensor = torch.tensor([tokens], device=self.device, dtype=torch.long)
        text_lens = torch.tensor([len(tokens)], device=self.device, dtype=torch.long)

        # Create batch dictionary
        batch = {
            'text': text_tensor,
            'text_lens': text_lens,
        }

        # Setup context in batch
        if self.use_text_conditioning_encoder:
            self.setup_dummy_text_context_in_batch(batch)
        self.setup_dummy_audio_context_in_batch(batch)

        # Run inference
        with torch.no_grad():
            output = self.infer_batch(
                batch,
                max_decoder_steps=max_decoder_steps,
                temperature=temperature,
                topk=topk,
                use_cfg=use_cfg,
                cfg_scale=cfg_scale,
            )

        return output.predicted_audio, output.predicted_audio_lens

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        return []