Rename distillation class to DistillationLoRA
armbues committed Nov 26, 2024
1 parent 8a6691c commit 168fd9d
Showing 2 changed files with 4 additions and 4 deletions.
sillm/distill.py (2 additions & 2 deletions)

@@ -2,7 +2,7 @@
 
 import sillm
 import sillm.utils as utils
-from sillm.experimental.distillation import Distillation
+from sillm.experimental.distillation import DistillationLoRA
 
 if __name__ == "__main__":
     # Parse commandline arguments
@@ -62,7 +62,7 @@
     distillation_config = {
        "loss_alpha": args.loss_alpha
    }
-    target_model = Distillation.from_model(target_model, draft_model, **distillation_config)
+    target_model = DistillationLoRA.from_model(target_model, draft_model, **distillation_config)
 
     # Initialize LoRA layers
     lora_config = {
sillm/experimental/distillation.py (2 additions & 2 deletions)

@@ -11,7 +11,7 @@
 
 logger = logging.getLogger("sillm")
 
-class Distillation(TrainableLoRA):
+class DistillationLoRA(TrainableLoRA):
     """
     Trainable distillation LLM.
     References:
@@ -26,7 +26,7 @@ def from_model(target_llm: LLM, draft_llm: LLM, **kwargs):
         Returns:
             Trainable distillation LLM.
         """
-        model = Distillation(target_llm.model, draft_llm.model, target_llm.tokenizer, draft_llm.tokenizer, target_llm.args, **kwargs)
+        model = DistillationLoRA(target_llm.model, draft_llm.model, target_llm.tokenizer, draft_llm.tokenizer, target_llm.args, **kwargs)
         model._quantization = target_llm._quantization
 
         return model
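
For context, a minimal sketch of how the renamed class is wired up, following the call shown in sillm/distill.py above. The model-loading and LoRA-initialization calls (sillm.load, init_lora) and the example paths and config values are assumptions not visible in this diff, included only to illustrate the renamed API surface:

import sillm
from sillm.experimental.distillation import DistillationLoRA

# Assumption: sillm.load() returns an LLM loaded from a model path (not part of this diff).
target_model = sillm.load("path/to/target-model")
draft_model = sillm.load("path/to/draft-model")

# Wrap the target model for distillation against the draft model,
# mirroring the DistillationLoRA.from_model call in sillm/distill.py.
distillation_config = {
    "loss_alpha": 0.5  # hypothetical value; the script takes it from args.loss_alpha
}
target_model = DistillationLoRA.from_model(target_model, draft_model, **distillation_config)

# Assumption: DistillationLoRA inherits init_lora() from TrainableLoRA;
# the actual lora_config keys are not shown in this diff.
lora_config = {}
target_model.init_lora(**lora_config)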
