# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import torch

from nemo.utils import AppState, logging

try:
    from apex.transformer.log_util import set_logging_level

    HAVE_APEX = True

except (ImportError, ModuleNotFoundError):

    HAVE_APEX = False

try:
    from megatron.core import tensor_parallel
    from megatron.core.parallel_state import (
        RankGenerator,
        get_pipeline_model_parallel_rank,
        set_expert_model_parallel_rank,
        set_expert_model_parallel_world_size,
        set_pipeline_model_parallel_rank,
        set_pipeline_model_parallel_world_size,
        set_tensor_model_parallel_rank,
        set_tensor_model_parallel_world_size,
    )

    HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

    HAVE_MEGATRON_CORE = False

try:
    from megatron.core.num_microbatches_calculator import (
        ConstantNumMicroBatchesCalculator,
        get_current_global_batch_size,
        get_micro_batch_size,
        get_num_microbatches,
        init_num_microbatches_calculator,
    )

    MCORE_MB_CALCULATOR = True

except (ImportError, ModuleNotFoundError):
    logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")

    if HAVE_APEX:
        from apex.transformer.microbatches import ConstantNumMicroBatches as ConstantNumMicroBatchesCalculator
        from apex.transformer.pipeline_parallel.utils import (
            get_current_global_batch_size,
            get_micro_batch_size,
            get_num_microbatches,
        )
        from apex.transformer.pipeline_parallel.utils import (
            setup_microbatch_calculator as init_num_microbatches_calculator,
        )

    MCORE_MB_CALCULATOR = False


def initialize_model_parallel_for_nemo(
    world_size,
    global_rank,
    local_rank,
    tensor_model_parallel_size=1,
    expert_model_parallel_size=1,
    expert_tensor_parallel_size=None,
    pipeline_model_parallel_size=1,
    virtual_pipeline_model_parallel_size=None,
    pipeline_model_parallel_split_rank=None,
    pipeline_model_parallel_comm_backend=None,
    context_parallel_size=1,
    encoder_tensor_model_parallel_size=0,
    encoder_pipeline_model_parallel_size=0,
    micro_batch_size=None,
    global_batch_size=None,
    rampup_batch_size=None,
    use_fp8=False,
    init_mpi_proc_group=False,
    seed=1234,
    apex_transformer_log_level=30,
    use_tp_pp_dp_mapping=False,
    use_te_rng_tracker=False,
    num_distributed_optimizer_instances=1,
    nccl_communicator_config_path=None,
    use_sharp=False,
    use_gloo_process_groups: bool = True,
):
    """Initialize model parallel groups in NeMo."""
    assert (
        pipeline_model_parallel_split_rank is None or pipeline_model_parallel_split_rank == 0
    ), "pipeline_model_parallel_split_rank is deprecated."
    assert encoder_pipeline_model_parallel_size == 0 and (
        encoder_tensor_model_parallel_size == 0 or encoder_tensor_model_parallel_size == tensor_model_parallel_size
    ), (
        "encoder_pipeline_model_parallel_size is temporarily "
        "unavailable. We are working on a refactoring to add it back."
    )

    # updating NeMo globals
    app_state = AppState()
    app_state.global_rank = global_rank
    app_state.world_size = world_size
    app_state.local_rank = local_rank
    app_state.use_tp_pp_dp_mapping = use_tp_pp_dp_mapping
    app_state.expert_model_parallel_size = expert_model_parallel_size
    app_state.tensor_model_parallel_size = tensor_model_parallel_size
    app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
    app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
    app_state.context_parallel_size = context_parallel_size
    app_state.encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size
    app_state.encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size
    app_state.pipeline_model_parallel_comm_backend = pipeline_model_parallel_comm_backend
    app_state.use_fp8 = use_fp8
    app_state.use_sharp = use_sharp
    app_state.init_mpi_proc_group = init_mpi_proc_group
    app_state.expert_tensor_parallel_size = expert_tensor_parallel_size
    app_state.num_distributed_optimizer_instances = num_distributed_optimizer_instances
    app_state.nccl_communicator_config_path = nccl_communicator_config_path
    app_state.use_gloo_process_groups = use_gloo_process_groups
    (
        app_state.tensor_model_parallel_rank,
        app_state.pipeline_model_parallel_rank,
        app_state.expert_model_parallel_rank,
        app_state.expert_tensor_parallel_rank,
        app_state.model_parallel_size,
        app_state.data_parallel_size,
        app_state.pipeline_model_parallel_split_rank,
        app_state.virtual_pipeline_model_parallel_rank,
    ) = fake_initialize_model_parallel(
        world_size=world_size,
        rank=global_rank,
        tensor_model_parallel_size_=tensor_model_parallel_size,
        pipeline_model_parallel_size_=pipeline_model_parallel_size,
        virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
        pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank,
        context_parallel_size_=context_parallel_size,
        expert_model_parallel_size_=expert_model_parallel_size,
        expert_tensor_parallel_size_=expert_tensor_parallel_size,
        encoder_tensor_model_parallel_size_=encoder_tensor_model_parallel_size,
        encoder_pipeline_model_parallel_size_=encoder_pipeline_model_parallel_size,
        use_tp_pp_dp_mapping=use_tp_pp_dp_mapping,
    )

    # update apex.transformer globals
    set_tensor_model_parallel_world_size(app_state.tensor_model_parallel_size)
    set_tensor_model_parallel_rank(app_state.tensor_model_parallel_rank)

    set_expert_model_parallel_world_size(app_state.expert_model_parallel_size)
    set_expert_model_parallel_rank(app_state.expert_model_parallel_rank)

    set_pipeline_model_parallel_world_size(
        app_state.pipeline_model_parallel_size + app_state.encoder_pipeline_model_parallel_size
    )
    set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank)

    tensor_parallel.random.initialize_rng_tracker(use_te_rng_tracker=use_te_rng_tracker)
    if seed is not None:
        # @chcui: not setting the seed is only intended for model conversion; always set the seed for training/inference.
        _set_random_seed(seed)

    if global_batch_size and micro_batch_size is not None:
        # TODO: add rampup_batch_size here when we have it implemented
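        # Example (assumed values): global_batch_size=128, micro_batch_size=2 and
        # data_parallel_size=8 yield 128 // (2 * 8) = 8 micro-batches per step.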
        if MCORE_MB_CALCULATOR:
            from megatron.core.num_microbatches_calculator import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
        else:
            from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR

        if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
            init_num_microbatches_calculator(
                rank=global_rank,
                global_batch_size=global_batch_size,
                micro_batch_size=micro_batch_size,
                data_parallel_size=app_state.data_parallel_size,
                rampup_batch_size=rampup_batch_size,
                decrease_batch_size_if_needed=False,
            )
        elif isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
            assert get_current_global_batch_size() == global_batch_size
            assert get_micro_batch_size() == micro_batch_size
            assert get_num_microbatches() == global_batch_size // (
                micro_batch_size * app_state.data_parallel_size
            )
        else:
            raise Exception("Microbatch calculator already initialized.")

    app_state._is_megatron_initialized = True

    if HAVE_APEX:
        set_logging_level(apex_transformer_log_level)
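

# Illustrative sketch only (not part of the NeMo API): one hypothetical way the entry point
# above could be called for a single 8-GPU node with 2-way tensor and 2-way pipeline
# parallelism. All argument values are assumptions chosen for demonstration; in practice
# they come from the training configuration and the launcher-provided ranks.
def _example_initialize_model_parallel_for_nemo(global_rank: int, local_rank: int) -> None:
    """Hedged usage sketch; the parallel sizes and batch sizes below are hypothetical."""
    initialize_model_parallel_for_nemo(
        world_size=8,
        global_rank=global_rank,
        local_rank=local_rank,
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
        context_parallel_size=1,
        micro_batch_size=1,
        global_batch_size=8,  # 8 // (1 micro-batch * 2 data-parallel) = 4 micro-batches per step
        seed=1234,
    )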


def _set_random_seed(seed_):
    """Set random seed for reproducability."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * get_pipeline_model_parallel_rank())
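        # For example, seed_=1234 on pipeline-parallel rank 3 yields an effective seed of 1534.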
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))


def set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # set flags if we are using the 21.10 container
    if torch.__version__ == "1.10.0a0+0aef44c":
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)


def fake_initialize_model_parallel(
    world_size,
    rank,
    tensor_model_parallel_size_,
    pipeline_model_parallel_size_,
    pipeline_model_parallel_split_rank_=None,
    virtual_pipeline_model_parallel_size_=None,
    expert_model_parallel_size_=1,
    expert_tensor_parallel_size_=None,
    context_parallel_size_=1,
    encoder_tensor_model_parallel_size_=0,
    encoder_pipeline_model_parallel_size_=0,
    use_tp_pp_dp_mapping=False,
):
    """
    Fake-initialize the model- and data-parallel groups so that model-parallel models can be
    instantiated before DDP is initialized. This is needed because the PTL execution flow is:
    init model, init trainer, then call trainer.fit(model); DDP is only initialized during .fit.
    This function is adapted from megatron.core.parallel_state and modified so that the
    distributed groups are not actually created.
    We only need the tensor-parallel and pipeline-parallel ranks to instantiate the model.

    Arguments:
        tensor_model_parallel_size_: number of GPUs used to parallelize the model tensors.
        pipeline_model_parallel_size_: number of GPUs used to parallelize the model pipeline.
        context_parallel_size_: number of GPUs used to parallelize the tokens of each input.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """

    assert pipeline_model_parallel_split_rank_ is None, "pipeline_model_parallel_split_rank is deprecated."
    assert encoder_pipeline_model_parallel_size_ == 0 and (
        encoder_tensor_model_parallel_size_ == 0 or encoder_tensor_model_parallel_size_ == tensor_model_parallel_size_
    ), (
        "encoder_pipeline_model_parallel_size is temporarily "
        "unavailable. We are working on a refactoring to add it back."
    )
    # Get world size and rank. Ensure some consistencies.
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
    model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
    context_parallel_size = min(context_parallel_size_, world_size)

    if encoder_pipeline_model_parallel_size_ is None:
        encoder_pipeline_model_parallel_size = 0
    else:
        encoder_pipeline_model_parallel_size = encoder_pipeline_model_parallel_size_

    if encoder_tensor_model_parallel_size_ == 0 and encoder_pipeline_model_parallel_size_ > 0:
        encoder_tensor_model_parallel_size = tensor_model_parallel_size
    else:
        encoder_tensor_model_parallel_size = encoder_tensor_model_parallel_size_

    if encoder_tensor_model_parallel_size > 0:
        assert encoder_pipeline_model_parallel_size > 0
        assert (
            encoder_tensor_model_parallel_size <= tensor_model_parallel_size
        ), "We do not support encoders with more TP than the decoder."

    encoder_model_size = (
        encoder_tensor_model_parallel_size * encoder_pipeline_model_parallel_size * context_parallel_size
    )
    decoder_model_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
    total_model_size = encoder_model_size + decoder_model_size

    assert world_size % total_model_size == 0, (
        f'world_size: {world_size} must be divisible by the total model size: '
        f'(decoder_)tensor_model_parallel_size {tensor_model_parallel_size} '
        f'* (decoder_)pipeline_model_parallel_size {pipeline_model_parallel_size} '
        f'* (decoder_)context_parallel_size {context_parallel_size} + '
        f'encoder_tensor_model_parallel_size {encoder_tensor_model_parallel_size} '
        f'* encoder_pipeline_model_parallel_size {encoder_pipeline_model_parallel_size} '
        f'* context_parallel_size {context_parallel_size}'
    )
    data_parallel_size = world_size // total_model_size
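    # For example (no encoder): world_size=16, TP=2, PP=4, CP=1 gives total_model_size=8
    # and data_parallel_size=2, matching the 16-GPU layout in the docstring above.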

    encoder_world_size = encoder_model_size * data_parallel_size
    decoder_world_size = decoder_model_size * data_parallel_size
    assert encoder_world_size + decoder_world_size == world_size

    virtual_pipeline_model_parallel_rank = None
    if virtual_pipeline_model_parallel_size_ is not None:
        virtual_pipeline_model_parallel_rank = 0

    if encoder_world_size > 0:
        encoder_rank_generator = RankGenerator(
            tp=encoder_tensor_model_parallel_size,
            ep=1,
            dp=data_parallel_size,
            pp=encoder_pipeline_model_parallel_size,
            cp=context_parallel_size,
            order='tp-cp-ep-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
            rank_offset=0,
        )
    else:
        encoder_rank_generator = None

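    # Note: the `order` string is assumed to list parallel dimensions from fastest- to
    # slowest-varying across global ranks, so the default 'tp-cp-ep-dp-pp' places
    # tensor-parallel ranks on adjacent GPUs (see the 16-GPU example in the docstring above).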
    decoder_rank_generator = RankGenerator(
        tp=tensor_model_parallel_size,
        ep=1,
        dp=data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=context_parallel_size,
        order='tp-cp-ep-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
        rank_offset=encoder_world_size,
    )
    # Build expert rank generator
    if expert_tensor_parallel_size_ is None:
        expert_tensor_parallel_size_ = tensor_model_parallel_size
    expert_tensor_model_pipeline_parallel_size = (
        expert_tensor_parallel_size_ * expert_model_parallel_size_ * pipeline_model_parallel_size
    )
    expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
    if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
        raise RuntimeError(
            f"decoder world_size ({decoder_world_size}) is not divisible by "
            f"expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
        )

    expert_decoder_rank_generator = RankGenerator(
        tp=expert_tensor_parallel_size_,
        ep=expert_model_parallel_size_,
        dp=expert_data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=1,
        order='tp-cp-ep-pp-dp' if use_tp_pp_dp_mapping else 'tp-cp-ep-dp-pp',
        rank_offset=encoder_world_size,
    )

    assert (
        not use_tp_pp_dp_mapping
        or pipeline_model_parallel_size == 1
        or expert_data_parallel_size == data_parallel_size
    ), "When not using pp-last rank ordering, the data parallel size of the attention and moe layers must be the same"

    assert decoder_rank_generator.get_ranks("pp") == expert_decoder_rank_generator.get_ranks("pp"), (
        "Pipeline parallel groups are expected to be the same for the non-expert and expert parts, "
        f"but got {decoder_rank_generator.get_ranks('pp')} and {expert_decoder_rank_generator.get_ranks('pp')}"
    )

    def generator_wrapper(group_type, is_expert=False, **kwargs):
        """The `RankGenerator` class produces a hyper-rectangle for a given set of
        tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
        in addition to the default decoder, we essentially instantiate two `RankGenerator`
        classes to construct the parallelism for each module separately, and we then have
        to stitch them together for the right groups. For now, this means pp and tp-pp."""
        from itertools import cycle

        if is_expert:
            d_ranks = expert_decoder_rank_generator.get_ranks(group_type, **kwargs)
        else:
            d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
        if encoder_rank_generator is None:
            for x in d_ranks:
                yield x
            return
        e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
        if group_type == 'pp':
            # Map 1 encoder tp rank to several decoder tp ranks, because
            # these won't be the same size.
            for x, y in zip(cycle(e_ranks), d_ranks):
                yield x + y
        elif group_type == 'tp-pp':
            # For this group, we can just return the concatenated
            # groups together, because their sizes are the same.
            assert len(e_ranks) == len(d_ranks)
            for x, y in zip(e_ranks, d_ranks):
                yield x + y
        else:
            for x in e_ranks:
                yield x
            for x in d_ranks:
                yield x

    # Build the data-parallel groups.
    all_data_parallel_group_ranks_with_cp = []
    for ranks in generator_wrapper('dp'):
        if rank in ranks:
            data_parallel_group = list(ranks)
            logging.info(f'Rank {rank} has data parallel group : {data_parallel_group}')

    for ranks_with_cp in generator_wrapper('dp-cp'):
        all_data_parallel_group_ranks_with_cp.append(ranks_with_cp)
        if rank in ranks_with_cp:
            data_parallel_group_with_cp = ranks_with_cp
            logging.info(
                f'Rank {rank} has combined group of data parallel and context parallel : {data_parallel_group_with_cp}'
            )

    data_parallel_rank = data_parallel_group.index(rank)
    logging.info(
        f'All data parallel group ranks with context parallel combined: {all_data_parallel_group_ranks_with_cp}'
    )
    logging.info(f'Rank {rank} has data parallel rank: {data_parallel_rank}')

    # Build the context-parallel groups.
    all_context_parallel_group_ranks = []
    for ranks in generator_wrapper('cp'):
        all_context_parallel_group_ranks.append(ranks)
        if rank in ranks:
            context_parallel_group = ranks
            logging.info(f'Rank {rank} has context parallel group: {context_parallel_group}')

    context_parallel_rank = context_parallel_group.index(rank)
    logging.info(f'All context parallel group ranks: {all_context_parallel_group_ranks}')
    logging.info(f'Rank {rank} has context parallel rank: {context_parallel_rank}')

    # Build the model-parallel groups.
    all_model_parallel_group_ranks = []
    for ranks in generator_wrapper('tp-pp'):
        all_model_parallel_group_ranks.append(ranks)
        if rank in ranks:
            logging.info(f'Rank {rank} has model parallel group: {list(ranks)}')
    logging.info(f'All model parallel group ranks: {all_model_parallel_group_ranks}')

    # Build the tensor model-parallel groups.
    all_tensor_model_parallel_group_ranks = []
    tensor_model_parallel_group = None
    for ranks in generator_wrapper('tp'):
        all_tensor_model_parallel_group_ranks.append(ranks)
        if rank in ranks:
            tensor_model_parallel_group = ranks
            logging.info(f'Rank {rank} has tensor model parallel group: {tensor_model_parallel_group}')

    tensor_model_parallel_rank = tensor_model_parallel_group.index(rank)

    logging.info(f'All tensor model parallel group ranks: {all_tensor_model_parallel_group_ranks}')
    logging.info(f'Rank {rank} has tensor model parallel rank: {tensor_model_parallel_rank}')

    # EP rank
    expert_model_parallel_rank = 0
    if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1:
        all_expert_model_parallel_ranks = []
        for ranks in generator_wrapper('ep', is_expert=True):
            all_expert_model_parallel_ranks.append(ranks)
            if rank in ranks:
                expert_model_parallel_rank = list(ranks).index(rank)
        logging.info(f'All expert model parallel group ranks: {all_expert_model_parallel_ranks}')
        logging.info(f'Rank {rank} has expert model parallel rank: {expert_model_parallel_rank}')

    # ETP
    expert_tensor_parallel_rank = 0
    if expert_tensor_parallel_size_ is not None and expert_tensor_parallel_size_ > 1:
        all_expert_tensor_parallel_ranks = []
        for ranks in generator_wrapper('tp', is_expert=True):
            all_expert_tensor_parallel_ranks.append(ranks)
            if rank in ranks:
                expert_tensor_parallel_rank = list(ranks).index(rank)
        logging.info(f'All expert tensor parallel group ranks: {all_expert_tensor_parallel_ranks}')
        logging.info(f'Rank {rank} has expert tensor parallel rank: {expert_tensor_parallel_rank}')

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    all_pipeline_model_parallel_group_ranks = []
    all_embedding_group_ranks = []
    pipeline_model_parallel_group = None
    embedding_group = None
    embedding_rank = None
    for ranks in generator_wrapper('pp'):
        all_pipeline_model_parallel_group_ranks.append(ranks)
        if rank in ranks:
            pipeline_model_parallel_group = ranks
            logging.info(f'Rank {rank} has pipeline model parallel group: {pipeline_model_parallel_group}')

        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            all_embedding_group_ranks.append(embedding_ranks)
        else:
            embedding_ranks = ranks
            all_embedding_group_ranks.append(list(embedding_ranks))
        if rank in embedding_ranks:
            embedding_group = list(embedding_ranks)
            logging.info(f'Rank {rank} has embedding group: {embedding_group}')

    pipeline_model_parallel_rank = pipeline_model_parallel_group.index(rank)
    if embedding_group is not None:
        embedding_rank = embedding_group.index(rank)

    logging.info(f'All pipeline model parallel group ranks: {all_pipeline_model_parallel_group_ranks}')
    logging.info(f'Rank {rank} has pipeline model parallel rank {pipeline_model_parallel_rank}')
    logging.info(f'All embedding group ranks: {all_embedding_group_ranks}')
    logging.info(f'Rank {rank} has embedding rank: {embedding_rank}')

    return (
        tensor_model_parallel_rank,
        pipeline_model_parallel_rank,
        expert_model_parallel_rank,
        expert_tensor_parallel_rank,
        model_parallel_size,
        data_parallel_size,
        pipeline_model_parallel_split_rank_,
        virtual_pipeline_model_parallel_rank,
    )
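

# Illustrative sketch only (not part of the NeMo API): a hypothetical check that
# `fake_initialize_model_parallel` reproduces the 16-GPU layout described in its docstring
# (TP=2, PP=4, hence DP=2). Rank g5 belongs to tensor group [g4, g5] and pipeline group
# [g1, g5, g9, g13], so its TP and PP ranks should both be 1.
def _example_fake_initialize_rank_mapping() -> None:
    """Hedged demonstration; the world size and parallel sizes below are assumptions."""
    (
        tensor_rank,
        pipeline_rank,
        _expert_rank,
        _expert_tensor_rank,
        model_parallel_size,
        data_parallel_size,
        _split_rank,
        _virtual_pipeline_rank,
    ) = fake_initialize_model_parallel(
        world_size=16,
        rank=5,
        tensor_model_parallel_size_=2,
        pipeline_model_parallel_size_=4,
    )
    assert (tensor_rank, pipeline_rank) == (1, 1)
    assert (model_parallel_size, data_parallel_size) == (8, 2)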