erl-j commited on
Commit
5661713
·
verified ·
1 Parent(s): a3a4c54

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ checkpoint/trainer_state.json filter=lfs diff=lfs merge=lfs -text
checkpoint/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Phi3ForCausalLM"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 0,
7
+ "embd_pdrop": 0.0,
8
+ "eos_token_id": 1,
9
+ "hidden_act": "silu",
10
+ "hidden_size": 512,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 2048,
13
+ "max_position_embeddings": 4096,
14
+ "model_type": "phi3",
15
+ "num_attention_heads": 8,
16
+ "num_hidden_layers": 6,
17
+ "num_key_value_heads": 8,
18
+ "original_max_position_embeddings": 4096,
19
+ "pad_token_id": 3,
20
+ "partial_rotary_factor": 1.0,
21
+ "resid_pdrop": 0.0,
22
+ "rms_norm_eps": 1e-05,
23
+ "rope_scaling": null,
24
+ "rope_theta": 10000.0,
25
+ "sliding_window": null,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.49.0",
29
+ "use_cache": true,
30
+ "vocab_size": 618
31
+ }
checkpoint/generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 0,
4
+ "eos_token_id": 1,
5
+ "pad_token_id": 3,
6
+ "transformers_version": "4.49.0"
7
+ }
checkpoint/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db08ddc79689f3fab5991a4546766c3017849ea0edfd15edd6d3b3d5581e058d
3
+ size 103225512
checkpoint/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:749e051ffbac8c9bd819313906d22632591f8a93c5a31329d84a28f3ef36924e
3
+ size 206475706
checkpoint/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7c94c375fe5ad2903d244ca6b5cc2a1a6cba4c0c26196f3b9cbd9ddd170bb0b8
3
+ size 14244
checkpoint/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11d424d5873a8cd68622baa36c10eafa92a53e48e582a3cec4682af6668d5418
3
+ size 1064
checkpoint/trainer_state.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90dee1a09be4e65a502af61c302567107374478c1ad62190c05150e482b28c4f
3
+ size 15033010
checkpoint/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4933ed958088f5852591052bd85276111256d2e4724b7428775dfa8298a37a87
3
+ size 5368
tokenisation.py ADDED
@@ -0,0 +1,779 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import symusic
2
+ import pretty_midi
3
+ import numpy as np
4
+ from dataclasses import dataclass, asdict
5
+ from typing import List, Tuple, Dict, TypeVar, Generic, Type
6
+ import json
7
+ import random
8
+ import logging
9
+ from util import crop_sm
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Constants
14
+ MIDI_DRUM_PITCHES = range(22, 82)
15
+
16
+
17
def crop_sm(sm, n_beats):
    """Crop a symbolic-music object to at most *n_beats* beats.

    NOTE(review): this local definition shadows the ``crop_sm`` imported
    from ``util`` at the top of the file — confirm which one is intended.

    Parameters
    ----------
    sm : object
        Symbolic music object exposing ``tpq``, ``end()``, ``copy()`` and
        ``clip()``.
    n_beats : int
        Number of beats to keep.

    Returns
    -------
    object
        A cropped copy; the input object itself is never modified.
    """
    cropped = sm.copy()
    limit_ticks = n_beats * cropped.tpq
    # Nothing to do when the material already fits inside the window.
    if cropped.end() <= limit_ticks:
        return cropped
    return cropped.clip(0, limit_ticks, clip_end=True)
43
+
44
+
45
class Quantizer:
    """Snap arbitrary values onto a fixed grid of evenly spaced bins."""

    def __init__(
        self, value_range: Tuple[float, float], n_bins: int, round_values: bool = False
    ):
        # Evenly spaced grid over [lo, hi]; optionally snapped to integers.
        lo, hi = value_range
        grid = np.linspace(lo, hi, n_bins)
        if round_values:
            grid = np.round(grid).astype(int)
        self.range = value_range
        self.n_bins = n_bins
        self.bins = grid

    # returns float or int depending on round_values
    def quantize(self, value: float):
        """Return the bin value closest to *value*."""
        distances = np.abs(self.bins - value)
        return self.bins[int(np.argmin(distances))]
59
+
60
@dataclass
class TokenizerConfig:
    """Empty base class; concrete tokenizer configs subclass this."""

    pass


# Generic type variable constrained to TokenizerConfig subclasses,
# used by BaseTokenizer[T] below.
T = TypeVar("T", bound=TokenizerConfig)
66
+
67
+
68
class BaseTokenizer(Generic[T]):
    """Abstract base class for MIDI tokenizers."""

    config_cls: Type[T]  # concrete subclasses point this at their config dataclass

    def __init__(self, config: T) -> None:
        self.config = config
        self.vocab: List[str] = []
        self.token_to_idx: Dict[str, int] = {}
        # Placeholder; subclasses are expected to overwrite this.
        self.pad_token_id = -1

    def to_json(self, path: str) -> None:
        """Serialize the tokenizer configuration to *path* as JSON."""
        with open(path, "w") as fp:
            json.dump(self.config.__dict__, fp, indent=2)

    @classmethod
    def from_json(cls, path: str):
        """Rebuild a tokenizer from a configuration JSON file."""
        with open(path, "r") as fp:
            raw = json.load(fp)
        return cls(cls.config_cls(**raw))

    def midi_to_tokens(self, midi: symusic.Score) -> List[str]:
        raise NotImplementedError

    def tokens_to_midi(self, tokens: List[str]) -> symusic.Score:
        raise NotImplementedError

    def ids_to_midi(self, ids: List[int]) -> symusic.Score:
        """Token ids -> tokens -> MIDI score."""
        return self.tokens_to_midi(self.ids_to_tokens(ids))

    def midi_to_ids(self, midi: symusic.Score) -> List[int]:
        """MIDI score -> tokens -> token ids."""
        return self.tokens_to_ids(self.midi_to_tokens(midi))

    def tokens_to_ids(self, tokens: List[str]) -> List[int]:
        """Convert token strings to vocabulary indices."""
        return [self.token_to_idx[tok] for tok in tokens]

    def ids_to_tokens(self, ids: List[int]) -> List[str]:
        """Convert vocabulary indices back to token strings."""
        return [self.vocab[i] for i in ids]
111
+
112
+
113
@dataclass
class TanjaTokenizerConfig(TokenizerConfig):
    """Configuration for :class:`TanjaTokenizer`."""

    ticks_per_beat: int
    coarse_ticks_per_beat: int
    tempo_range: Tuple[int, int]
    n_tempo_bins: int
    n_velocity_bins: int
    n_bars: int
    n_events: int

    def dict(self):
        """Return all fields as a flat str -> str mapping."""
        return {key: str(val) for key, val in asdict(self).items()}
125
+
126
class TanjaTokenizer(BaseTokenizer):
    """CMLM tokenizer with a fixed-length event grid.

    Token stream layout::

        Tempo_<qpm>                                                  # header
        Program Pitch Onset Microtiming Offset Duration Velocity    # event 1
        Program Pitch Onset Microtiming Offset Duration Velocity    # event 2
        ...

    Exactly ``config.n_events`` events are always emitted; unused slots are
    filled with ``*_inactive`` tokens.  A ``MASK_None`` token is available
    for masked-language-model training.  Only 4/4 time signatures are
    supported.
    """

    # Lets the inherited BaseTokenizer.from_json rebuild the right config
    # class (was missing in the original).
    config_cls = TanjaTokenizerConfig

    def __init__(self, config: TanjaTokenizerConfig):
        # Initialise vocab/token_to_idx/pad_token_id via the base class
        # (the original skipped super().__init__()).
        super().__init__(config)

        # Only 4/4 is supported, so a bar is always four beats.
        self.n_beats = config.n_bars * 4

        # Special tokens first so their ids are small and stable.
        self.vocab.extend(
            ["BOS_None", "EOS_None", "SEP_None", "PAD_None", "MASK_None"]
        )

        # Tempo tokens over the quantized qpm grid.
        self.tempo_quantizer = Quantizer(
            config.tempo_range, config.n_tempo_bins, round_values=True
        )
        self.vocab.extend(f"Tempo_{tempo}" for tempo in self.tempo_quantizer.bins)

        # Program tokens: 128 GM programs, drums, and an inactive state.
        self.vocab.extend(f"Program_{i}" for i in range(128))
        self.vocab.append("Program_Drums")
        self.vocab.append("Program_inactive")

        # Pitch tokens: melodic, drum, and an inactive state.
        self.vocab.extend(f"Pitch_{pitch}" for pitch in range(128))
        self.vocab.extend(f"Pitch_Drum{pitch}" for pitch in range(128))
        self.vocab.append("Pitch_inactive")

        # Coarse onset grid covering the whole clip.
        self.vocab.extend(
            f"Onset_{i}"
            for i in range(
                0,
                self.n_beats * config.ticks_per_beat,
                config.coarse_ticks_per_beat,
            )
        )
        self.vocab.append("Onset_inactive")

        # Fine onset (microtiming) within a single coarse step.
        self.vocab.extend(
            f"Microtiming_{i}" for i in range(config.coarse_ticks_per_beat)
        )
        self.vocab.append("Microtiming_inactive")

        # Coarse offset grid; one extra beat so notes can end at the clip edge.
        self.vocab.extend(
            f"Offset_{i}"
            for i in range(
                0,
                (self.n_beats + 1) * config.ticks_per_beat,
                config.coarse_ticks_per_beat,
            )
        )
        self.vocab.append("Offset_inactive")

        # Duration tokens: powers of two from a 1/32 note to the full clip.
        thirtysecond_ticks = (config.ticks_per_beat * 4) // 32
        clip_ticks = config.ticks_per_beat * self.n_beats
        ticks = thirtysecond_ticks
        while ticks <= clip_ticks:
            self.vocab.append(f"Duration_{ticks}")
            ticks *= 2
        self.durations = [
            int(t.split("_")[-1]) for t in self.vocab if t.startswith("Duration_")
        ]
        self.vocab.append("Duration_inactive")

        # Velocity tokens over the quantized 1..127 range, plus inactive.
        self.velocity_quantizer = Quantizer(
            (1, 127), config.n_velocity_bins, round_values=True
        )
        self.vocab.extend(f"Velocity_{v}" for v in self.velocity_quantizer.bins)
        self.vocab.append("Velocity_inactive")

        # Fixed order of the 7 attribute tokens that make up one event.
        self.event_attribute_order = [
            "Program",
            "Pitch",
            "Onset",
            "Microtiming",
            "Offset",
            "Duration",
            "Velocity",
        ]

        self.token_to_idx = {token: idx for idx, token in enumerate(self.vocab)}
        # Expose the real PAD id instead of the base-class placeholder (-1);
        # with the vocab above this is 3, matching the shipped model config.
        self.pad_token_id = self.token_to_idx["PAD_None"]

    def remove_special_tokens(self, tokens: List[str]) -> List[str]:
        """Drop BOS/EOS/SEP/PAD tokens (MASK_None is intentionally kept)."""
        special_tokens = ["BOS_None", "EOS_None", "SEP_None", "PAD_None"]
        return [token for token in tokens if token not in special_tokens]

    def get_inactive_note_tokens(self):
        """Return the 7-token representation of an unused event slot."""
        return [f"{attr}_inactive" for attr in self.event_attribute_order]

    def get_closest_duration(self, duration: float) -> int:
        """Return the supported duration nearest to *duration*.

        Nearest by absolute difference (the original docstring claimed
        round-down, which the code never did).
        """
        return min(self.durations, key=lambda x: abs(x - duration))

    def get_note_tokens(self, note, program, is_drums):
        """Encode one note as its 7 attribute tokens."""
        coarse = self.config.coarse_ticks_per_beat
        # The onset splits into a coarse grid position plus a fine remainder.
        onset_coarse = int(coarse * (note.start // coarse))
        onset_fine = int(note.start % coarse)
        # Offsets are coarse only, capped at the end of the clip.
        offset = min(
            int(coarse * (note.end // coarse)),
            self.n_beats * self.config.ticks_per_beat,
        )
        duration = self.get_closest_duration(note.end - note.start)
        return [
            "Program_Drums" if is_drums else f"Program_{program}",
            f"Pitch_Drum{note.pitch}" if is_drums else f"Pitch_{note.pitch}",
            f"Onset_{onset_coarse}",
            f"Microtiming_{onset_fine}",
            f"Offset_{offset}",
            f"Duration_{duration}",
            f"Velocity_{self.velocity_quantizer.quantize(note.velocity)}",
        ]

    def midi_to_token_ids(self, midi, shuffle_events=True):
        """Convert a MIDI score to token IDs."""
        return self.tokens_to_ids(self.midi_to_tokens(midi, shuffle_events))

    def token_ids_to_midi(self, token_ids):
        """Convert token IDs back to a MIDI score."""
        return self.tokens_to_midi(self.ids_to_tokens(token_ids))

    def midi_to_tokens(self, midi, shuffle_events=True):
        """Encode a symusic.Score into a fixed-length token list."""
        assert midi.note_num() > 0, "MIDI file must contain at least one note"
        assert (
            midi.note_num() <= self.config.n_events
        ), "MIDI file must contain less than n_events notes"

        # Work on a copy at our tick resolution, cropped to the clip length.
        midi = midi.copy().resample(self.config.ticks_per_beat)
        midi = crop_sm(midi, self.n_beats)

        time_signature = midi.time_signatures[-1]
        if time_signature.numerator != 4 or time_signature.denominator != 4:
            raise ValueError(
                "Only 4/4 time signature is supported for Tanja tokenizer."
            )

        # Fall back to 120 qpm when the score carries no tempo events.
        tempo = midi.tempos[-1].qpm if len(midi.tempos) > 0 else 120
        tempo_token = f"Tempo_{self.tempo_quantizer.quantize(tempo)}"

        note_tokens = [
            self.get_note_tokens(note, track.program, track.is_drum)
            for track in midi.tracks
            for note in track.notes
        ]
        # Sort lexicographically for determinism before optional shuffling.
        note_tokens.sort()
        # Pad with inactive slots so the sequence length is always fixed.
        for _ in range(self.config.n_events - len(note_tokens)):
            note_tokens.append(self.get_inactive_note_tokens())
        if shuffle_events:
            note_tokens = random.sample(note_tokens, len(note_tokens))

        tokens = [tempo_token]
        for event in note_tokens:
            tokens.extend(event)
        assert tokens[0].startswith("Tempo_"), "First token must be a tempo token"
        return tokens

    def get_prob_mask(self, idx):
        """Return a 0/1 mask over the vocab of tokens legal at position *idx*.

        Position 0 must be a Tempo token; later positions cycle through
        ``event_attribute_order`` (7 attribute tokens per event).
        """
        if idx == 0:
            return [1 if token.startswith("Tempo_") else 0 for token in self.vocab]
        attr = self.event_attribute_order[(idx - 1) % len(self.event_attribute_order)]
        return [1 if token.startswith(attr) else 0 for token in self.vocab]

    def _parse_note_tokens(self, note_tokens):
        """Parse one 7-token event.

        Returns a tuple (program, pitch, onset_coarse, onset_fine, offset,
        duration, velocity) with program == -1 for drums, or None when any
        attribute is inactive (an unused event slot).
        """
        values = []
        for token, attr in zip(note_tokens, self.event_attribute_order):
            assert token.startswith(f"{attr}_"), (
                f"Expected a {attr} token, got {token}"
            )
            values.append(token.split("_")[-1])
        if "inactive" in values:
            return None
        program_str, pitch_str, onset_str, micro_str, offset_str, dur_str, vel_str = values
        program = -1 if program_str == "Drums" else int(program_str)
        pitch = (
            int(pitch_str.split("Drum")[-1]) if "Drum" in pitch_str else int(pitch_str)
        )
        return (
            program,
            pitch,
            int(onset_str),
            int(micro_str),
            int(offset_str),
            int(dur_str),
            int(vel_str),
        )

    def tokens_to_midi(self, tokens):
        """Decode a token list produced by :meth:`midi_to_tokens` into a Score."""
        tokens = self.remove_special_tokens(tokens.copy())

        midi = symusic.Score().resample(self.config.ticks_per_beat)

        # Header: the first token carries the tempo.
        tempo_token = tokens.pop(0)
        midi.tempos = [symusic.Tempo(qpm=int(tempo_token.split("_")[-1]), time=0)]
        midi.time_signatures.append(
            symusic.TimeSignature(numerator=4, denominator=4, time=0)
        )

        n_attrs = len(self.event_attribute_order)
        program_notes = {}

        while len(tokens) > 0:
            note_tokens = tokens[:n_attrs]
            tokens = tokens[n_attrs:]
            # Was a bare print(); use the module logger so decoding is quiet.
            logger.debug("Note tokens: %s", note_tokens)

            parsed = self._parse_note_tokens(note_tokens)
            if parsed is None:
                continue  # inactive slot — no note here
            program, pitch, onset_coarse, onset_fine, offset, _token_dur, velocity = parsed

            # Re-attach the microtiming to both ends: the coarse offset keeps
            # the same fine displacement as the onset.  The Duration token is
            # redundant with Onset/Offset and is ignored (as in the original).
            onset_tick = onset_coarse + onset_fine
            offset_tick = offset + onset_fine

            program_notes.setdefault(program, []).append(
                symusic.Note(
                    time=onset_tick,
                    pitch=pitch,
                    velocity=velocity,
                    duration=offset_tick - onset_tick,
                )
            )

        # Deterministic output: programs ascending; notes sorted inside tracks.
        for program, notes in sorted(program_notes.items(), key=lambda kv: kv[0]):
            notes.sort(key=lambda n: (n.start, n.end, n.pitch, n.velocity))
            track = symusic.Track(
                is_drum=program == -1, program=program if program != -1 else 0
            )
            for note in notes:
                track.notes.append(note)
            midi.tracks.append(track)
        return midi
468
+
469
+ # header
470
+ # Tempo_120 Program_Drums Program_1 Program_34 Program_1 Track_None Bar_None Position_0 Offset_2 Pitch_Drum:60 Velocity ... Track_None Bar_None Postion_0 Offset_2 Pitch_60 Velocity_100 Duration_46 ... ... Track_None
471
@dataclass
class IrmaTokenizerConfig(TokenizerConfig):
    """Configuration for :class:`IrmaTokenizer`."""

    ticks_per_beat: int
    positions_per_beat: int
    tempo_range: Tuple[int, int]
    n_tempo_bins: int
    n_velocity_bins: int
    n_bars: int
    duration_ranges: List[Tuple[int, int]]

    def dict(self):
        """Return all fields as a flat str -> str mapping."""
        return {key: str(val) for key, val in asdict(self).items()}
483
+
484
+
485
class IrmaTokenizer(BaseTokenizer):
    """Irma tokenizer.

    Stream layout::

        Tempo_<qpm> Program_* ...          # header: tempo, then one Program
                                           # token per upcoming track body
        Track_None Bar_None Position_12 Shift_2 Pitch_60 Velocity_100 Duration_...
        Track_None ...

    One body part per program, each introduced by ``Track_None``.  ``Shift``
    tokens encode the onset remainder inside a ``Position`` step and are only
    emitted when they change.  Only 4/4 time signatures are supported.
    """

    config_cls = IrmaTokenizerConfig

    def __init__(self, config: IrmaTokenizerConfig):
        super().__init__(config)

        # Ticks covered by one Position step.  NOTE(review): this is a float
        # when ticks_per_beat is not divisible by positions_per_beat —
        # confirm configs always keep it integral (96/12 = 8 in the shipped
        # tokenizer_config.json).
        self.ticks_per_position = (
            self.config.ticks_per_beat / self.config.positions_per_beat
        )

        self.vocab = []
        # Special tokens.
        self.vocab.extend(
            ["BOS_None", "EOS_None", "SEP_None", "PAD_None", "Bar_None", "Track_None"]
        )

        # Tempo tokens over the quantized qpm grid.
        self.tempo_quantizer = Quantizer(
            self.config.tempo_range, self.config.n_tempo_bins, round_values=True
        )
        self.vocab.extend(f"Tempo_{tempo}" for tempo in self.tempo_quantizer.bins)

        # Program tokens: 128 GM programs plus drums.
        self.vocab.extend(f"Program_{i}" for i in range(128))
        self.vocab.append("Program_Drums")

        # Position tokens: one per grid step inside a 4/4 bar.
        positions_per_bar = 4 * config.positions_per_beat
        self.vocab.extend(f"Position_{i}" for i in range(positions_per_bar))

        # Shift tokens: the non-zero tick remainders inside one position step.
        n_offsets = config.ticks_per_beat / config.positions_per_beat
        self.vocab.extend(f"Shift_{i}" for i in range(1, int(n_offsets)))

        # Pitch tokens (melodic and drum).
        self.vocab.extend(f"Pitch_{pitch}" for pitch in range(128))
        self.vocab.extend(f"Pitch_Drum{pitch}" for pitch in range(128))

        # Duration tokens.  Each (beats, divisions) entry contributes
        # durations up to `beats` beats at a resolution of
        # ticks_per_beat/divisions ticks; the division must split a beat
        # evenly.
        for dur_range in self.config.duration_ranges:
            assert self.config.ticks_per_beat % dur_range[1] == 0, f"Duration division {dur_range[1]} must be a divisor of ticks_per_beat {self.config.ticks_per_beat}"

        self.durations = []
        range_start = 0
        for range_end, divisions in self.config.duration_ranges:
            range_start_ticks = range_start * self.config.ticks_per_beat
            range_end_ticks = range_end * self.config.ticks_per_beat
            step_ticks = int(self.config.ticks_per_beat / divisions)
            for ticks in range(range_start_ticks, range_end_ticks, step_ticks):
                self.vocab.append(f"Duration_{ticks}d{self.config.ticks_per_beat * 4}")
                self.durations.append(ticks)
            range_start = range_end

        # Velocity tokens over the quantized 1..127 range.
        self.velocity_quantizer = Quantizer(
            (1, 127), self.config.n_velocity_bins, round_values=True
        )
        self.vocab.extend(f"Velocity_{v}" for v in self.velocity_quantizer.bins)

        self.token_to_idx = {token: idx for idx, token in enumerate(self.vocab)}
        # Expose the real PAD id instead of the base-class placeholder (-1).
        self.pad_token_id = self.token_to_idx["PAD_None"]

    def midi_to_token_ids(self, midi: symusic.Score, shuffle_tracks=True) -> List[int]:
        """Convert a MIDI score to token IDs."""
        return self.tokens_to_ids(self.midi_to_tokens(midi, shuffle_tracks))

    def remove_special_tokens(self, tokens: List[str]) -> List[str]:
        """Drop BOS/EOS/SEP/PAD tokens (Bar_None/Track_None are structural)."""
        special_tokens = ["BOS_None", "EOS_None", "SEP_None", "PAD_None"]
        return [token for token in tokens if token not in special_tokens]

    def token_ids_to_midi(self, token_ids: List[int]) -> symusic.Score:
        """Convert token IDs back to a MIDI score."""
        return self.tokens_to_midi(self.ids_to_tokens(token_ids))

    def get_closest_duration(self, duration: float) -> int:
        """Return the supported duration nearest to *duration*."""
        return min(self.durations, key=lambda x: abs(x - duration))

    def midi_to_tokens(self, midi: symusic.Score, shuffle_tracks=True) -> List[str]:
        """Encode a MIDI score into Irma tokens."""
        midi = midi.copy().resample(self.config.ticks_per_beat)

        # Fall back to 120 qpm when the score carries no tempo events.
        tempo = midi.tempos[-1].qpm if len(midi.tempos) > 0 else 120
        time_signature = midi.time_signatures[-1]
        if time_signature.numerator != 4 or time_signature.denominator != 4:
            raise ValueError(
                "Only 4/4 time signature is supported for Irma tokenizer."
            )

        tempo_token = f"Tempo_{self.tempo_quantizer.quantize(tempo)}"

        # Keep only non-empty tracks; optionally randomize their order so the
        # model does not learn a fixed track ordering.
        tracks = [track for track in midi.tracks if len(track.notes) > 0]
        if shuffle_tracks:
            tracks = random.sample(tracks, len(tracks))

        program_tokens = []
        track_tokens = []
        for track in tracks:
            program_tokens.append(
                "Program_Drums" if track.is_drum else f"Program_{track.program}"
            )
            track_tokens.append(self._encode_track(track))

        tokens = [tempo_token, *program_tokens]
        for body in track_tokens:
            tokens.extend(body)
        return tokens

    def _encode_track(self, track) -> List[str]:
        """Encode one track body: Track_None, then bar/position/note tokens."""
        ticks_per_bar = self.config.ticks_per_beat * 4
        body = ["Track_None"]
        bar_count = -1
        curr_position = -1
        curr_shift = 0

        notes = track.notes.copy()
        notes.sort(key=lambda note: (note.start, note.pitch, note.velocity))
        for note in notes:
            # Emit Bar_None for every bar boundary crossed since the last note.
            bar_idx = note.start // ticks_per_bar
            while bar_count < bar_idx:
                body.append("Bar_None")
                bar_count += 1
                curr_position = -1
                curr_shift = 0

            onset = note.start

            # Coarse position within the bar; emitted only on change.
            position = int(onset % ticks_per_bar // self.ticks_per_position)
            if position != curr_position:
                body.append(f"Position_{position}")
                curr_position = position
                curr_shift = 0

            # Fine shift within the position step; emitted only on change.
            shift = int(onset % self.ticks_per_position)
            if shift != curr_shift:
                body.append(f"Shift_{shift}")
                curr_shift = shift

            if track.is_drum:
                body.append(f"Pitch_Drum{note.pitch}")
            else:
                body.append(f"Pitch_{note.pitch}")

            body.append(
                f"Velocity_{self.velocity_quantizer.quantize(note.velocity)}"
            )

            # Snap the duration to the nearest supported value.
            closest_duration = self.get_closest_duration(note.end - note.start)
            body.append(
                f"Duration_{closest_duration}d{self.config.ticks_per_beat * 4}"
            )
        return body

    @staticmethod
    def _split_by(lst, value):
        """Split *lst* into sublists at every occurrence of *value* (dropped)."""
        result = []
        current_sublist = []
        for item in lst:
            if item == value:
                if current_sublist:  # save the current sublist if non-empty
                    result.append(current_sublist)
                    current_sublist = []
            else:
                current_sublist.append(item)
        if current_sublist:  # last sublist, if non-empty
            result.append(current_sublist)
        return result

    def tokens_to_midi(self, tokens):
        """Decode Irma tokens back into a symusic.Score."""
        tokens = self.remove_special_tokens(tokens.copy())

        # Header: tempo first, then one Program token per track body.
        assert tokens[0].startswith("Tempo_"), "First token must be a tempo token"
        tempo_token = tokens.pop(0)

        program_tokens = []
        while tokens and not tokens[0].startswith("Track_None"):
            pr_token = tokens.pop(0)
            assert pr_token.startswith("Program_"), "Program token must start with Program_"
            program_tokens.append(pr_token)

        midi = symusic.Score().resample(self.config.ticks_per_beat)
        midi.tempos = [symusic.Tempo(qpm=int(tempo_token.split("_")[-1]), time=0)]
        midi.time_signatures.append(
            symusic.TimeSignature(numerator=4, denominator=4, time=0)
        )

        tokens_split_by_track = self._split_by(tokens, "Track_None")
        assert len(tokens_split_by_track) == len(program_tokens), "Number of tracks must be equal to number of programs"

        for track_tokens, track_program in zip(tokens_split_by_track, program_tokens):
            is_drum = track_program == "Program_Drums"
            track = symusic.Track(
                is_drum=is_drum,
                program=0 if is_drum else int(track_program.split("_")[-1]),
            )
            bar_count = -1
            curr_position = 0
            curr_shift = 0
            # Set by Pitch_/Velocity_ tokens before the Duration_ that closes
            # a note; None until then.
            pitch = None
            velocity = None
            for token in track_tokens:
                if token.startswith("Bar_None"):
                    bar_count += 1
                    curr_position = 0
                elif token.startswith("Position_"):
                    curr_position = int(token.split("_")[-1])
                    curr_shift = 0
                elif token.startswith("Shift_"):
                    curr_shift = int(token.split("_")[-1])
                elif token.startswith("Pitch_"):
                    pitch_str = token.split("_")[-1]
                    if pitch_str.startswith("Drum"):
                        pitch = int(pitch_str.split("Drum")[-1])
                    else:
                        pitch = int(pitch_str)
                elif token.startswith("Velocity_"):
                    velocity = int(token.split("_")[-1])
                elif token.startswith("Duration_"):
                    # Duration closes the note: emit it at the current
                    # bar/position/shift with the last Pitch and Velocity.
                    duration = int(token.split("_")[-1].split("d")[0])
                    onset = int(
                        bar_count * self.config.ticks_per_beat * 4
                        + curr_position * self.ticks_per_position
                        + curr_shift
                    )
                    track.notes.append(
                        symusic.Note(
                            time=onset,
                            pitch=pitch,
                            velocity=velocity,
                            duration=duration,
                        )
                    )
            midi.tracks.append(track)

        return midi
tokenizer_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ticks_per_beat": 96,
3
+ "positions_per_beat": 12,
4
+ "tempo_range": [
5
+ 60,
6
+ 250
7
+ ],
8
+ "n_tempo_bins": 32,
9
+ "n_velocity_bins": 32,
10
+ "n_bars": 4,
11
+ "duration_ranges": [
12
+ [
13
+ 2,
14
+ 12
15
+ ],
16
+ [
17
+ 16,
18
+ 6
19
+ ]
20
+ ]
21
+ }