gpt-oss-20b

A model compatible with the EasyDeL JAX stack.

Overview

This checkpoint is intended to be loaded with EasyDeL on JAX (CPU/GPU/TPU). It supports sharded loading with auto_shard_model=True and configurable precision via dtype, param_dtype, and precision.

Quickstart

import easydel as ed
from jax import numpy as jnp, lax

repo_id = "EasyDeL/gpt-oss-20b"

dtype = jnp.bfloat16  # try jnp.float16 on many GPUs

model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    repo_id,
    dtype=dtype,
    param_dtype=dtype,
    precision=lax.Precision("fastest"),
    sharding_axis_names=("dp", "fsdp", "ep", "tp", "sp"),
    sharding_axis_dims=(1, -1, 1, 1, 1),  # -1 lets EasyDeL infer the fsdp size from the device count
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_dtype=dtype,
        attn_mechanism=ed.AttentionMechanisms.RAGGED_PAGE_ATTENTION_V3,
        fsdp_is_ep_bound=True,
        sp_is_ep_bound=True,
        moe_method=ed.MoEMethods.FUSED_MOE,
    ),
    auto_shard_model=True,
    partition_axis=ed.PartitionAxis(),
)

If the repository only provides PyTorch weights, pass from_torch=True to from_pretrained(...).

Sharding & Parallelism (Multi-Device)

EasyDeL can scale to multiple devices by creating a logical device mesh. Most EasyDeL loaders use a 5D mesh:

  • dp: data parallel (replicated parameters, different batch shards)
  • fsdp: parameter sharding (memory saver; often the biggest axis)
  • ep: expert parallel (MoE; keep 1 for non-MoE models)
  • tp: tensor parallel (splits large matmuls)
  • sp: sequence parallel (splits sequence dimension)

Use sharding_axis_names=("dp","fsdp","ep","tp","sp") and choose sharding_axis_dims so that their product equals your device count. You can use -1 in sharding_axis_dims to let EasyDeL infer the remaining dimension.
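The -1 rule can be sketched in plain Python. This is a minimal illustration, not EasyDeL's own code: resolve_axis_dims is a hypothetical helper, and it assumes that at most one axis is set to -1 and that the device count is known up front.

```python
from math import prod


def resolve_axis_dims(axis_dims, device_count):
    """Hypothetical helper: fill in a single -1 so the product equals device_count."""
    dims = list(axis_dims)
    if any(d == 0 or d < -1 for d in dims):
        raise ValueError("axis sizes must be positive integers or -1")
    unknown = [i for i, d in enumerate(dims) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis may be -1 in this sketch")
    # Product of the explicitly specified axes.
    known = prod(d for d in dims if d != -1)
    if unknown:
        if device_count % known != 0:
            raise ValueError(f"{device_count} devices cannot be split by {known}")
        dims[unknown[0]] = device_count // known
    elif known != device_count:
        raise ValueError(f"product {known} does not match {device_count} devices")
    return tuple(dims)


print(resolve_axis_dims((1, -1, 1, 1, 1), 8))  # (1, 8, 1, 1, 1)
print(resolve_axis_dims((2, -1, 1, 2, 1), 8))  # (2, 2, 1, 2, 1)
```

Checking the dims this way before loading can surface a mismatched device count early, rather than during mesh creation.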

Example sharding configs
# 8 devices, pure FSDP
sharding_axis_dims = (1, 8, 1, 1, 1)

# 8 devices, 2-way DP x 4-way FSDP
sharding_axis_dims = (2, 4, 1, 1, 1)

# 8 devices, 4-way FSDP x 2-way TP
sharding_axis_dims = (1, 4, 1, 2, 1)
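For a rough sense of how these configs trade off memory, the back-of-envelope sketch below estimates parameter bytes per device. It is an idealization, not a measurement: param_bytes_per_device is a hypothetical helper, the 20 billion parameter count is an assumed round figure taken from the model name, and the sketch assumes that only the fsdp and tp axes split parameters evenly while dp and sp replicate them. Activations, KV cache, optimizer state, and unevenly sharded tensors are ignored.

```python
AXIS_NAMES = ("dp", "fsdp", "ep", "tp", "sp")


def param_bytes_per_device(n_params, bytes_per_param, axis_dims,
                           sharded_axes=("fsdp", "tp")):
    """Hypothetical estimate: total parameter bytes divided by the shard count."""
    sizes = dict(zip(AXIS_NAMES, axis_dims))
    shards = 1
    for name in sharded_axes:
        # Assumption: each listed axis splits every parameter evenly.
        shards *= sizes[name]
    return n_params * bytes_per_param / shards


N_PARAMS = 20_000_000_000  # assumed round figure, not an exact count
BF16_BYTES = 2             # bfloat16 stores each value in 2 bytes

for dims in [(1, 8, 1, 1, 1), (2, 4, 1, 1, 1), (1, 4, 1, 2, 1)]:
    gb = param_bytes_per_device(N_PARAMS, BF16_BYTES, dims) / 1e9
    print(dims, f"{gb:.1f} GB of parameters per device")
```

Under these assumptions, the 2-way DP x 4-way FSDP layout holds twice as many parameter bytes per device as pure 8-way FSDP, because the dp axis replicates parameters instead of sharding them.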

Using eLargeModel (ELM)

eLargeModel is a higher-level interface that wires together loading, sharding, training, and eSurge inference from a single config.

from easydel import eLargeModel

repo_id = "EasyDeL/gpt-oss-20b"

elm = eLargeModel.from_pretrained(repo_id)  # task is auto-detected
elm.set_dtype("bf16")
elm.set_sharding(axis_names=("dp", "fsdp", "ep", "tp", "sp"), axis_dims=(1, -1, 1, 1, 1))

model = elm.build_model()
# Optional: build an inference engine
# engine = elm.build_esurge()

ELM YAML config example
model:
  name_or_path: "EasyDeL/gpt-oss-20b"

loader:
  dtype: bf16
  param_dtype: bf16

sharding:
  axis_dims: [1, -1, 1, 1, 1]
  auto_shard_model: true

Features

EasyDeL provides:

  • JAX native implementation and sharded execution
  • Configurable attention backends via AttentionMechanisms.*
  • Precision control via dtype, param_dtype, and precision

Installation

pip install easydel

Supported Tasks

  • CausalLM

Limitations

  • Refer to the original model card for training data, evaluation, and intended use.

License

EasyDeL is released under the Apache-2.0 license. The license for this model's weights may differ; please consult the original repository.

Citation

@misc{ZareChavoshi_2023,
    title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
    url={https://github.com/erfanzar/EasyDeL},
    author={Zare Chavoshi, Erfan},
    year={2023}
}