-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Add AutoSP example #999
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
neeldani
wants to merge
1
commit into
deepspeedai:master
Choose a base branch
from
neeldani:autosp-example
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add AutoSP example #999
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| *.log | ||
| *.pyc | ||
| logs | ||
| *. | ||
| *.pt | ||
| output |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| # AutoSP Benchmarking Examples | ||
|
|
||
| This directory contains AutoSP benchmarking examples that demonstrate model compilation and optimization techniques using DeepSpeed and HuggingFace Accelerate. The example script show four compilation modes (AutoSP and baselines) for training large language models: | ||
|
|
||
| | Mode | Parallelism Strategy | Execution Backend | | ||
| |------|----------------------|-------------------| | ||
| | **eager** | Ulysses DistributedAttention | PyTorch Eager | | ||
| | **compile** | Ulysses DistributedAttention | PyTorch Inductor | | ||
| | **autosp** | Automatic Sequence Parallelism | AutoSP Compiler | | ||
| | **ringattn** | RingAttention-style Sequence Parallelism | PyTorch Inductor | | ||
|
|
||
| ## Files in this Directory | ||
|
|
||
| - **run.py**: Benchmarking script with an option to choose either of the 4 compilation modes listed above | ||
| - **run_autosp.sh**: Launcher script that configures training runs across multiple GPUs using Hugging Face Accelerate | ||
| - **sp_dp_registry.py**: Sequence Parallel and Data Parallel mesh management utilities | ||
| - **distributed_attention.py**: Ulysses-styled sequence paralllelism which can be plugged in as an attention backend for HuggingFace | ||
| - **ring_attention.py**: Ring Attention algorithm implementation which can be plugged in as an attention backend for HuggingFace | ||
| - **configs/**: Training configuration templates for different model sizes and scenarios | ||
| - **correctness/**: Correctness validation suite for AutoSP | ||
| - **correctness_run.py**: Runs training for a specific configuration (compile mode, sequence parallel size, ZeRO stage) and saves per-rank losses to a JSON file for comparison | ||
| - **correctness.sh**: Launcher script that orchestrates correctness testing across multiple configurations, running both baseline (compiled Ulysses) and AutoSP modes | ||
| - **validator.py**: Compares per-rank losses between AutoSP and baseline to verify numerical correctness within a configurable threshold | ||
|
|
||
| ## Setup Guide | ||
|
|
||
| Quick start guide to set up the AutoSP example. This example demonstrates usage of AutoSP with [HuggingFace Accelerate](https://huggingface.co/docs/accelerate/index) for distributed training across multiple GPUs. | ||
|
|
||
| ### Install dependencies | ||
|
|
||
| ```bash | ||
| pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 | ||
| ``` | ||
|
|
||
| ```bash | ||
| pip install \ | ||
| transformers==4.50.3 \ | ||
| tokenizers \ | ||
| huggingface-hub \ | ||
| safetensors \ | ||
| datasets \ | ||
| accelerate \ | ||
| scipy \ | ||
| tqdm \ | ||
| pyyaml | ||
| ``` | ||
|
|
||
| ## Benchmarking | ||
|
|
||
| The `benchmarks/autosp/` directory contains for benchmarking scripts: | ||
|
|
||
| ```bash | ||
| cd benchmarks/autosp | ||
| ``` | ||
|
|
||
| #### Run autosp on 2 GPUs | ||
| ```bash | ||
| ./run_autosp.sh --compile autosp --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic | ||
| ``` | ||
|
|
||
| #### Run eager mode ulysses on 2 GPUs | ||
| ```bash | ||
| ./run_autosp.sh --compile eager --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic | ||
| ``` | ||
|
|
||
| #### Run torch.compile'd ulysses on 2 GPUs | ||
| ```bash | ||
| ./run_autosp.sh --compile compile --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic | ||
| ``` | ||
|
|
||
| #### Run torch.compile'd ring attention on 2 GPUs | ||
| ```bash | ||
| ./run_autosp.sh --compile ringattn --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic | ||
| ``` | ||
|
|
||
| ## Correctness Testing | ||
|
|
||
| To validate that AutoSP produces numerically correct results matching the baseline, use the correctness test suite: | ||
|
|
||
| ```bash | ||
| cd correctness | ||
| ./correctness.sh # Test default sp-sizes: 1, 2, 4, 8 | ||
| ./correctness.sh 2,1 # Test with custom (sp-sizes, dp_size) | ||
| ``` | ||
|
|
||
| This runs training for each configuration with both baseline (compiled Ulysses) and AutoSP modes, then compares per-rank losses to verify correctness. | ||
|
|
||
| ### Expected Output | ||
|
|
||
| When running the correctness suite with sp_size=2, you should see output similar to: | ||
|
|
||
| ``` | ||
| ================================================================ | ||
| AutoSP Correctness Test Suite | ||
| ================================================================ | ||
| Configs (sp,dp): 2,1 4,1 8,1 | ||
| Zero stages: 0 1 | ||
| Steps: 5 | ||
| Output dir: /u/ndani/DeepSpeedExamples/benchmarks/autosp/correctness/output | ||
| ================================================================ | ||
|
|
||
| ---------------------------------------------------------------- | ||
| Test: sp_size=2, dp_size=1, zero_stage=0 | ||
| ---------------------------------------------------------------- | ||
| [1/3] Running baseline (--compile compile) ... | ||
| Losses saved: 2 rank(s), 6 step(s) -> /u/ndani/DeepSpeedExamples/benchmarks/autosp/correctness/output/sp2_dp1_zero0/baseline.json | ||
| [2/3] Running autosp (--compile autosp) ... | ||
| Losses saved: 2 rank(s), 6 step(s) -> /u/ndani/DeepSpeedExamples/benchmarks/autosp/correctness/output/sp2_dp1_zero0/autosp.json | ||
| [3/3] Validating per-rank losses ... | ||
| PASS (max diff: 3.861427e-03, threshold: 1.000000e-02) | ||
|
|
||
| ---------------------------------------------------------------- | ||
| Test: sp_size=2, dp_size=1, zero_stage=1 | ||
| ---------------------------------------------------------------- | ||
| [1/3] Running baseline (--compile compile) ... | ||
| Losses saved: 2 rank(s), 6 step(s) -> /u/ndani/DeepSpeedExamples/benchmarks/autosp/correctness/output/sp2_dp1_zero1/baseline.json | ||
| [2/3] Running autosp (--compile autosp) ... | ||
| Losses saved: 2 rank(s), 6 step(s) -> /u/ndani/DeepSpeedExamples/benchmarks/autosp/correctness/output/sp2_dp1_zero1/autosp.json | ||
| [3/3] Validating per-rank losses ... | ||
| PASS (max diff: 3.166199e-03, threshold: 1.000000e-02) | ||
|
|
||
| ================================================================ | ||
| SUMMARY | ||
| ================================================================ | ||
| sp2_dp1_zero0: PASS | ||
| sp2_dp1_zero1: PASS | ||
| ``` | ||
|
|
||
| All tests should PASS with loss differences within the configurable threshold (default: 1.0e-2). | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| { | ||
|
|
||
| "bf16": { | ||
| "enabled": true | ||
| }, | ||
|
|
||
| "zero_optimization": { | ||
| "stage": 0 | ||
| }, | ||
| "compile": { | ||
| "deepcompile": true, | ||
| "passes": ["autosp"] | ||
| }, | ||
| "gradient_accumulation_steps": 1, | ||
| "gradient_clipping": "auto", | ||
| "steps_per_print": 2000, | ||
| "train_batch_size": 1, | ||
| "train_micro_batch_size_per_gpu": 1, | ||
| "wall_clock_breakdown": false, | ||
| "sequence_parallel_size": 2 | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| compute_environment: LOCAL_MACHINE | ||
| debug: false | ||
| deepspeed_config: | ||
| deepspeed_multinode_launcher: standard | ||
| deepspeed_config_file: configs/autosp_config.json | ||
| distributed_type: DEEPSPEED | ||
| machine_rank: 0 | ||
| main_training_function: main | ||
| num_machines: 1 | ||
| num_processes: 2 | ||
| rdzv_backend: static | ||
| same_network: true | ||
| tpu_env: [] | ||
| tpu_use_cluster: false | ||
| tpu_use_sudo: false | ||
| use_cpu: false |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| { | ||
| "bf16": { | ||
| "enabled": true | ||
| }, | ||
| "zero_optimization":{ | ||
| "stage": 0 | ||
| }, | ||
| "gradient_accumulation_steps": 1, | ||
| "gradient_clipping": "auto", | ||
| "steps_per_print": 2000, | ||
| "train_batch_size": "auto", | ||
| "train_micro_batch_size_per_gpu": "auto", | ||
| "wall_clock_breakdown": false | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| compute_environment: LOCAL_MACHINE | ||
| debug: false | ||
| deepspeed_config: | ||
| deepspeed_multinode_launcher: standard | ||
| deepspeed_config_file: configs/torchcompile_config.json | ||
| distributed_type: DEEPSPEED | ||
| machine_rank: 0 | ||
| main_training_function: main | ||
| num_machines: 1 | ||
| num_processes: 2 | ||
| rdzv_backend: static | ||
| same_network: true | ||
| tpu_env: [] | ||
| tpu_use_cluster: false | ||
| tpu_use_sudo: false | ||
| use_cpu: false |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| #!/bin/bash | ||
|
|
||
| # Correctness test suite for autosp vs baseline compiled DS-Ulysses. | ||
| # | ||
| # For each (sp_size, dp_size) x zero_stage configuration: | ||
| # 1. Runs baseline (--compile compile) for N steps | ||
| # 2. Runs autosp (--compile autosp) for N steps | ||
| # 3. Compares per-rank losses with validator.py | ||
| # | ||
| # Usage: | ||
| # ./correctness.sh # Default configs | ||
| # ./correctness.sh 2,1 2,2 4,1 # Custom sp,dp pairs | ||
|
|
||
| SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | ||
| OUTPUT_DIR="${SCRIPT_DIR}/output" | ||
| STEPS=5 | ||
|
|
||
| # Parse sp,dp pairs from positional args (e.g. 2,1 2,2 4,1) | ||
| declare -a CONFIGS=() | ||
|
|
||
| if [ $# -gt 0 ]; then | ||
| for arg in "$@"; do | ||
| if [[ "$arg" =~ ^([0-9]+),([0-9]+)$ ]]; then | ||
| CONFIGS+=("$arg") | ||
| else | ||
| echo "Error: invalid config '${arg}'. Expected format: sp,dp (e.g. 2,1)" | ||
| exit 1 | ||
| fi | ||
| done | ||
| else | ||
| CONFIGS=("2,1" "4,1" "8,1") | ||
| fi | ||
|
|
||
| ZERO_STAGES=(0 1) | ||
|
|
||
| PASS_COUNT=0 | ||
| FAIL_COUNT=0 | ||
| TOTAL_COUNT=0 | ||
| declare -a RESULTS=() | ||
|
|
||
| echo "" | ||
| echo "================================================================" | ||
| echo " AutoSP Correctness Test Suite" | ||
| echo "================================================================" | ||
| echo " Configs (sp,dp): ${CONFIGS[*]}" | ||
| echo " Zero stages: ${ZERO_STAGES[*]}" | ||
| echo " Steps: ${STEPS}" | ||
| echo " Output dir: ${OUTPUT_DIR}" | ||
| echo "================================================================" | ||
| echo "" | ||
|
|
||
| for config in "${CONFIGS[@]}"; do | ||
| sp_size="${config%%,*}" | ||
| dp_size="${config##*,}" | ||
|
|
||
| for zero_stage in "${ZERO_STAGES[@]}"; do | ||
| TEST_NAME="sp${sp_size}_dp${dp_size}_zero${zero_stage}" | ||
| TEST_DIR="${OUTPUT_DIR}/${TEST_NAME}" | ||
| mkdir -p "${TEST_DIR}" | ||
|
|
||
| ((TOTAL_COUNT++)) | ||
|
|
||
| echo "----------------------------------------------------------------" | ||
| echo " Test: sp_size=${sp_size}, dp_size=${dp_size}, zero_stage=${zero_stage}" | ||
| echo "----------------------------------------------------------------" | ||
|
|
||
| # --- Baseline (compiled DS-Ulysses) --- | ||
| echo " [1/3] Running baseline (--compile compile) ..." | ||
| if ! python3 "${SCRIPT_DIR}/correctness_run.py" \ | ||
| --compile compile \ | ||
| --sp-size "${sp_size}" \ | ||
| --dp-size "${dp_size}" \ | ||
| --zero-stage "${zero_stage}" \ | ||
| --steps "${STEPS}" \ | ||
| --output-file "${TEST_DIR}/baseline.json"; then | ||
|
|
||
| echo " FAIL: Baseline training failed" | ||
| RESULTS+=(" ${TEST_NAME}: FAIL (baseline training error)") | ||
| ((FAIL_COUNT++)) | ||
| echo "" | ||
| continue | ||
| fi | ||
|
|
||
| # --- AutoSP --- | ||
| echo " [2/3] Running autosp (--compile autosp) ..." | ||
| if ! python3 "${SCRIPT_DIR}/correctness_run.py" \ | ||
| --compile autosp \ | ||
| --sp-size "${sp_size}" \ | ||
| --dp-size "${dp_size}" \ | ||
| --zero-stage "${zero_stage}" \ | ||
| --steps "${STEPS}" \ | ||
| --output-file "${TEST_DIR}/autosp.json"; then | ||
|
|
||
| echo " FAIL: AutoSP training failed" | ||
| RESULTS+=(" ${TEST_NAME}: FAIL (autosp training error)") | ||
| ((FAIL_COUNT++)) | ||
| echo "" | ||
| continue | ||
| fi | ||
|
|
||
| # --- Validate --- | ||
| echo " [3/3] Validating per-rank losses ..." | ||
| if python3 "${SCRIPT_DIR}/validator.py" \ | ||
| --baseline "${TEST_DIR}/baseline.json" \ | ||
| --autosp "${TEST_DIR}/autosp.json"; then | ||
|
|
||
| RESULTS+=(" ${TEST_NAME}: PASS") | ||
| ((PASS_COUNT++)) | ||
| else | ||
| RESULTS+=(" ${TEST_NAME}: FAIL") | ||
| ((FAIL_COUNT++)) | ||
| fi | ||
|
|
||
| echo "" | ||
| done | ||
| done | ||
|
|
||
| # ---- Summary ---- | ||
| echo "================================================================" | ||
| echo " SUMMARY" | ||
| echo "================================================================" | ||
| for result in "${RESULTS[@]}"; do | ||
| echo "${result}" | ||
| done | ||
| echo "" | ||
| echo " Passed: ${PASS_COUNT}/${TOTAL_COUNT} Failed: ${FAIL_COUNT}/${TOTAL_COUNT}" | ||
| echo "================================================================" | ||
|
|
||
| if [ "${FAIL_COUNT}" -gt 0 ]; then | ||
| exit 1 | ||
| fi | ||
| exit 0 |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Include screenshot/snippet of expected output, this is helpful for users.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added expected output here