
Commit

Merge branch 'main' into background-job
Ivan-Zhou committed Jan 28, 2024
2 parents 311b378 + efef064 commit 5c57b33
Showing 19 changed files with 905 additions and 227 deletions.
1 change: 1 addition & 0 deletions config/llama2_nano.yaml
@@ -9,6 +9,7 @@ model:
  type: llama
  hidden_dim: 32
  num_heads: 4
  num_kv_heads: 4
  num_layers: 2
trainer:
  wandb:
389 changes: 218 additions & 171 deletions docs/Fine-Tuning.md

Large diffs are not rendered by default.

Binary file added docs/figures/finetune_func_cm_full_weight.png
Binary file added docs/figures/finetune_func_cm_lora.png
427 changes: 427 additions & 0 deletions docs/tutorials/Fine-Tuning-Semantic-Parsing.md

Large diffs are not rendered by default.

19 changes: 19 additions & 0 deletions examples/alpaca/alpaca-llama2.yaml
@@ -13,3 +13,22 @@ trainer:
optimizer:
  learning_rate: 2e-5
  weight_decay: 0.0
prompts:
  # |- means multiline string, keeping all but the final newline
  prompt_input: |-
    Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

    ### Instruction:
    {instruction}

    ### Input:
    {input}

    ### Response:
  prompt_no_input: |-
    Below is an instruction that describes a task. Write a response that appropriately completes the request.

    ### Instruction:
    {instruction}

    ### Response:
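For orientation, a minimal sketch of how a template like prompt_no_input above can be filled in. It assumes plain str.format-style substitution on the {instruction} placeholder and uses a made-up instruction; the exact call site in alpaca.py may differ.

    # Minimal sketch, assuming str.format-style substitution (illustrative only).
    prompt_no_input = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    )
    example = {"instruction": "Name three primary colors."}
    print(prompt_no_input.format(**example))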
11 changes: 6 additions & 5 deletions examples/alpaca/alpaca.py
@@ -35,6 +35,7 @@
# Ways this script could be improved:
# * Could tune hparams more for throughput

# Original
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -93,7 +94,7 @@ class TrainArgs:

model_cache_dir: Optional[str] = None # Path to cache the model. must be local.

hf_save_path: Optional[str] = None # Path to save the HuggingFace checkpoint.
hf_save_path: Optional[str] = "alpaca_hf_ckpts" # Path to save the HuggingFace checkpoint, can be gcs
hf_upload: Union[bool, str] = False # Name of the HuggingFace repo to upload to (if any).
hf_save_steps: int = 1000 # How often to save the HuggingFace checkpoint.

@@ -134,14 +135,14 @@ def _get_data_source(path_or_id):
"""The original alpaca.py used a json file, but it's since been moved to the HF dataset hub. You can use any
dataset that's compatible with the structure of the alpaca dataset."""
if fsspec_utils.exists(path_or_id):
# get file format: jsonl or json
if path_or_id.endswith(".jsonl"):
# we're a bit generous here b/c we support compression
if ".jsonl" in path_or_id:
return JsonlDataset([path_or_id])
elif path_or_id.endswith(".json"):
elif ".json" in path_or_id:
return JsonDataset([path_or_id])
else:
raise ValueError(
f"We only support HF Dataset or a data file with .json or .jsonl extensions, not {path_or_id}!"
f"We only support HF Datasets or a data file with .json or .jsonl extensions, not {path_or_id}!"
)
else:
return WrappedHFDataset(path_or_id, split="train")
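A quick, hedged illustration of the relaxed extension check above (the filename is hypothetical): a compressed shard still contains ".jsonl", so the substring test routes it to JsonlDataset, whereas the old endswith() test would have rejected it.

    # Illustrative only; the path is made up.
    path = "alpaca_data.jsonl.gz"
    assert ".jsonl" in path             # new check: matches compressed shards too
    assert not path.endswith(".jsonl")  # old check: would not have matched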
19 changes: 19 additions & 0 deletions examples/alpaca/alpaca.yaml
@@ -13,3 +13,22 @@ trainer:
optimizer:
  learning_rate: 2e-5
  weight_decay: 0.0
prompts:
  # |- means multiline string, keeping all but the final newline
  prompt_input: |-
    Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

    ### Instruction:
    {instruction}

    ### Input:
    {input}

    ### Response:
  prompt_no_input: |-
    Below is an instruction that describes a task. Write a response that appropriately completes the request.

    ### Instruction:
    {instruction}

    ### Response:
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -94,6 +94,7 @@ nav:
- 'Tutorials':
- "Fine-Tuning.md"
- "LoRA.md"
- "tutorials/Fine-Tuning-Semantic-Parsing.md"
- "Hardware-Agnostic-Training.md"
- 'Developer Guide':
- 'Port-Models.md'
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -51,7 +51,7 @@ dependencies = [
"ray[default]",
"pydantic<2", # temporary pin until Ray supports pydantic 2.0
"rich>=13",
# "chex>=0.1.85"
"filelock",
]

[tool.hatch.build]
4 changes: 2 additions & 2 deletions src/levanter/checkpoint.py
@@ -231,7 +231,7 @@ def save_checkpoint(self, info, destination: str):
logger.info(f"Saved checkpoint at step {info.step} to {path}. Save time is {self._last_save_time}")


def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike, *, exist_ok: bool = False):
def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike):
"""
Save a checkpoint to a given path using TensorStore. The checkpoint directory is created if it
does not already exist, so saving to a path that already contains a checkpoint is allowed.
@@ -247,7 +247,7 @@ def save_checkpoint(model, training_state, step: int, checkpoint_path: PathLike,

fs: AbstractFileSystem
fs, plain_path = _get_fs_and_plain_path(checkpoint_path)
fs.makedirs(plain_path, exist_ok=exist_ok)
fs.makedirs(plain_path, exist_ok=True)

tree_serialize_leaves_tensorstore(os.path.join(checkpoint_path, "model"), model)
if training_state is not None:
92 changes: 69 additions & 23 deletions src/levanter/distributed.py
@@ -224,29 +224,30 @@ def _munge_address_port(address: str):
# this is no longer the case, so instead we need to check if we are the coordinator
# and if so, start the head

if _is_this_machine(host):
logger.info(f"Starting ray head on port {ray_port}. We are process the coordinator {host}.")
logger.info(f"Starting ray with num_cpus set to {num_cpus}.")
ret = os.system(
f"ray start --head --port {ray_port} --num-cpus {num_cpus} --dashboard-host=0.0.0.0"
)
if ret != 0:
raise RuntimeError(f"Failed to start ray head with exit code {ret}")
else:
logger.info(f"Successfully started ray head on port {ray_port}.")

# install an atexit handler to kill the head when we exit
atexit.register(lambda: os.system("ray stop -g 10 --force"))
elif start_workers:
logger.info(
f"Starting ray worker and connecting to {address}. We are process {jax.process_index()}."
)
logger.info(f"Starting ray with num_cpus set to {num_cpus}.")
ret = os.system(f"ray start --address {address} --num-cpus {num_cpus}")
if ret != 0:
raise RuntimeError(f"Failed to start ray head with exit code {ret}")
else:
logger.info(f"Successfully started ray worker and connected to {address}.")
if _is_local_leader():
if _is_this_machine(host):
logger.info(f"Starting ray head on port {ray_port}. We are process the coordinator {host}.")
logger.info(f"Starting ray head with num_cpus set to {num_cpus}.")
ret = os.system(
f"ray start --head --port {ray_port} --num-cpus {num_cpus} --dashboard-host=0.0.0.0"
)
if ret != 0:
raise RuntimeError(f"Failed to start ray head with exit code {ret}")
else:
logger.info(f"Successfully started ray head on port {ray_port}.")

# install an atexit handler to kill the head when we exit
atexit.register(lambda: os.system("ray stop -g 10 --force"))
elif start_workers:
logger.info(
f"Starting ray worker and connecting to {address}. We are process {jax.process_index()}."
)
logger.info(f"Starting ray worker with num_cpus set to {num_cpus}.")
ret = os.system(f"ray start --address {address} --num-cpus {num_cpus}")
if ret != 0:
raise RuntimeError(f"Failed to start ray head with exit code {ret}")
else:
logger.info(f"Successfully started ray worker and connected to {address}.")

logger.info(f"ray.init(address={repr(address)}, namespace={repr(namespace)}, **{repr(kwargs)})")
# Ray has retry logic, so we don't need to retry here :fingers-crossed:
@@ -318,6 +319,9 @@ def _is_this_machine(host):
"""
Checks if the given host identifies this machine.
"""
if host == "localhost" or host == "0.0.0.0":
return True

try:
# Get IP addresses of all interfaces
machine_ips = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)]
@@ -330,3 +334,45 @@ def _is_this_machine(host):

# Check if the host IP matches any of the machine IPs
return any(host_ip == machine_ip for machine_ip in machine_ips)
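A small, hedged usage sketch of the added short-circuit (hostnames are just examples, and _is_this_machine is the private helper shown above, imported here only for illustration): names that always refer to the local machine now return True without the interface lookup.

    # Hypothetical calls; requires levanter to be installed.
    from levanter.distributed import _is_this_machine

    assert _is_this_machine("localhost")  # returns True before any DNS/interface lookup
    assert _is_this_machine("0.0.0.0")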


def _remove_if_possible(path):
try:
os.remove(path)
except OSError:
pass


def _touch(file_path):
with open(file_path, "a"):
os.utime(file_path, None)


def _is_local_leader():
import atexit

import filelock
from jax.experimental.multihost_utils import broadcast_one_to_all

if jax.process_count() == 1:
return True

import random

random_id = random.randint(0, 1000000)
random_id = broadcast_one_to_all(random_id)

lock = filelock.FileLock(f"/tmp/levanter_local_process_zero_lock.{random_id}")
action_performed_file = f"/tmp/levanter_local_process_zero_action_performed.{random_id}"

try:
    with lock.acquire(timeout=0.1):
        # register cleanup before returning so the lock and marker files are removed at exit
        atexit.register(_remove_if_possible, lock.lock_file)
        atexit.register(_remove_if_possible, action_performed_file)
        if not os.path.exists(action_performed_file):
            _touch(action_performed_file)
            return True  # Action needs to be performed
        else:
            return False  # Action already performed
except filelock.Timeout:
    return False
2 changes: 0 additions & 2 deletions src/levanter/lora.py
@@ -500,12 +500,10 @@ def to_hf_config(config: LoraConfig, base_model_name_or_path: Optional[str] = No
return {
"base_model_name_or_path": base_model_name_or_path,
"bias": "none", # TODO: support bias
"enable_lora": None,
"fan_in_fan_out": False, # TODO: support fan_in_fan_out
"inference_mode": True, # TODO: support inference_mode
"lora_alpha": config.alpha,
"lora_dropout": 0.00, # TODO: support dropout
"merge_weights": False,
"modules_to_save": None, # TODO: support modules_to_save?
"peft_type": "LORA",
"r": config.r,
32 changes: 25 additions & 7 deletions src/levanter/models/llama.py
@@ -47,6 +47,9 @@ class LlamaConfig(HFCompatConfig):
intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008.
num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32.
num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32.
num_kv_heads (int, optional): number of attention heads for keys and values in each attention layer.
Setting this to 1 gives multi-query attention (MQA); setting it to num_heads gives standard multi-head attention (MHA); anything in between gives grouped-query attention (GQA).
Note that num_heads must be divisible by this number. Defaults to 32.
activation_function (str, optional): activation function for the hidden layer. Defaults to "silu".
rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding.
"""
@@ -56,6 +59,7 @@
intermediate_dim: int = 11008
num_layers: int = 32
num_heads: int = 32
num_kv_heads: int = 32
activation_function: str = "silu"
initializer_range: float = 0.02
layer_norm_epsilon: float = 1e-5
@@ -76,10 +80,16 @@
KeyPos = property(lambda self: self.Pos.alias("key_position"))
Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim))
Heads = property(lambda self: Axis(name="heads", size=self.num_heads))
KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads))
Layers = property(lambda self: Axis(name="layers", size=self.num_layers))
Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim))
HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads))

def __post_init__(self):
assert (
self.num_heads % self.num_kv_heads == 0
), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}."

@cached_classproperty
def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore
return HFCheckpointConverter(
@@ -98,6 +108,7 @@ def from_hf_config(cls, hf_config: HfConfig):
intermediate_dim=hf_config.intermediate_size,
num_layers=hf_config.num_hidden_layers,
num_heads=hf_config.num_attention_heads,
num_kv_heads=hf_config.num_key_value_heads,
activation_function=hf_config.hidden_act,
initializer_range=hf_config.initializer_range,
layer_norm_epsilon=hf_config.rms_norm_eps,
@@ -123,6 +134,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
intermediate_size=self.intermediate_dim,
num_hidden_layers=self.num_layers,
num_attention_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
hidden_act=self.activation_function,
initializer_range=self.initializer_range,
rms_norm_eps=self.layer_norm_epsilon,
@@ -264,10 +276,14 @@ class LlamaAttention(StateDictSerializationMixin, eqx.Module):
def init(config: LlamaConfig, *, key) -> "LlamaAttention":
use_bias = config.use_bias
Embed = config.Embed
QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads)

k_q, k_k, k_v, k_o = jrandom.split(key, 4)
q_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_q, use_bias=use_bias)
k_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_k, use_bias=use_bias)
v_proj = hnn.Linear.init(In=Embed, Out=(config.Heads, config.HeadSize), key=k_v, use_bias=use_bias)
q_proj = hnn.Linear.init(
In=Embed, Out=(config.KVHeads, QHeadsPerGroup, config.HeadSize), key=k_q, use_bias=use_bias
)
k_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=use_bias)
v_proj = hnn.Linear.init(In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=use_bias)
o_proj = hnn.Linear.init(In=(config.Heads, config.HeadSize), Out=Embed, key=k_o, use_bias=use_bias)
rotary_emb = LlamaRotaryEmbedding(config.HeadSize, config.Pos)
return LlamaAttention(config, q_proj, k_proj, v_proj, o_proj, rotary_emb)
@@ -277,9 +293,9 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> Na
key_q, key_k, key_v, key_o = maybe_rng_split(key, 4)

# reorder heads and position for better training throughput
q = self.q_proj(x, key=key_q).rearrange((..., "heads", "position", "head_size"))
k = self.k_proj(x, key=key_k).rearrange((..., "heads", "position", "head_size"))
v = self.v_proj(x, key=key_v).rearrange((..., "heads", "position", "head_size"))
q = self.q_proj(x, key=key_q).rearrange((..., "kv_heads", "q_heads_per_group", "position", "head_size"))
k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size"))
v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size"))

cos, sin = self.rotary_emb(seq_len=x.axis_size("position"))

@@ -305,6 +321,8 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray], *, key=None) -> Na
flash_block_size=c.flash_attention_block_size,
)

attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads")

if self.config.upcast_attn:
attn_output = attn_output.astype(x.dtype)

@@ -574,7 +592,7 @@ def _rotate_half(x: NamedArray) -> NamedArray:


def _apply_rotary_pos_emb(
q: NamedArray, # [batch, position, heads, head_size]
q: NamedArray, # [batch, position, kv_heads, q_heads_per_group, head_size]
k: NamedArray, # [batch, position, kv_heads, head_size]
cos: NamedArray, # [position, head_size]
sin: NamedArray, # [position, head_size]
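To make the grouped-query attention bookkeeping in LlamaAttention concrete, here is a shape-only sketch in plain NumPy (illustrative sizes, not from the diff; the real code works with named haliax axes): with num_heads=32 and num_kv_heads=8, every group of 4 query heads shares one key/value head, and the (kv_heads, q_heads_per_group) pair is flattened back into a single heads axis after attention.

    import numpy as np

    num_heads, num_kv_heads, seq_len, head_size = 32, 8, 16, 128
    q_heads_per_group = num_heads // num_kv_heads   # 4; __post_init__ asserts divisibility

    # shapes after the projections in LlamaAttention.init (batch axis omitted)
    q = np.zeros((num_kv_heads, q_heads_per_group, seq_len, head_size))  # all 32 query heads, grouped
    k = np.zeros((num_kv_heads, seq_len, head_size))                     # only 8 key heads
    v = np.zeros((num_kv_heads, seq_len, head_size))                     # only 8 value heads

    # each query head in a group attends against the same k/v head; afterwards the
    # grouped axes collapse back into the usual heads axis, as flatten_axes does above
    attn_output = np.zeros((num_kv_heads, q_heads_per_group, seq_len, head_size))
    attn_output = attn_output.reshape(num_heads, seq_len, head_size)
    print(attn_output.shape)  # (32, 16, 128)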