Update demo to streaming mode
bofenghuang committed Apr 6, 2023
1 parent 1fc595d commit f8c7a38
Showing 3 changed files with 293 additions and 242 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -35,6 +35,8 @@ This project is based on [LLaMA](https://github.com/facebookresearch/llama), [St

- 2023/3/29: Add instructions for deploying using [llama.cpp](https://github.com/ggerganov/llama.cpp)
- 2023/4/3: Add fine-tuning scripts for seq2seq models
- 2023/4/6: Improve the quality of the translated Alpaca dataset
- 2023/4/6: Update Gradio demo to streaming mode

## Setup

@@ -54,13 +56,12 @@ The fine-tuned instruction-following vigogne models are available on 🤗 Huggin
- Fine-tuned LLaMA-13B model: [bofenghuang/vigogne-lora-13b](https://huggingface.co/bofenghuang/vigogne-lora-13b)
- Fine-tuned LLaMA-30B model: [bofenghuang/vigogne-lora-30b](https://huggingface.co/bofenghuang/vigogne-lora-30b)
- Fine-tuned BLOOM-7B1 model: [bofenghuang/vigogne-lora-bloom-7b1](https://huggingface.co/bofenghuang/vigogne-lora-bloom-7b1)
- Fine-tuned OPT-6.7B model: [bofenghuang/vigogne-lora-opt-6.7b](https://huggingface.co/bofenghuang/vigogne-lora-opt-6.7b)

You can run inference with these models using the following Google Colab notebook.

<a href="https://colab.research.google.com/github/bofenghuang/vigogne/blob/main/infer.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
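As a complement to the notebook, here is a minimal loading sketch for one of the LoRA adapters listed above. It is not the project's `infer.ipynb`: the base checkpoint name and the fp16/`device_map` settings below are illustrative assumptions.

```python
# Minimal sketch: load a LLaMA base model and apply a Vigogne LoRA adapter with peft.
# Assumptions: "huggyllama/llama-13b" stands in for whichever base checkpoint the adapter
# targets; dtype/device settings are illustrative and require a GPU plus accelerate.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_name = "huggyllama/llama-13b"          # placeholder base checkpoint
lora_model_name = "bofenghuang/vigogne-lora-13b"  # one of the adapters listed above

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, device_map="auto")
model = PeftModel.from_pretrained(model, lora_model_name)  # wrap the base weights with the LoRA deltas
model.eval()
```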

You can also run a Gradio demo using the following command:
You can also run a Gradio demo in streaming mode by using the following command:

```bash
./demo.py \
61 changes: 50 additions & 11 deletions demo.py
@@ -5,13 +5,22 @@
Modified from: https://github.com/tloen/alpaca-lora/blob/main/generate.py
"""

import logging
import sys
from threading import Thread

import fire
import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaTokenizer, TextIteratorStreamer

logging.basicConfig(
format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
datefmt="%Y-%m-%dT%H:%M:%SZ",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

if torch.cuda.is_available():
device = "cuda"
@@ -24,6 +33,8 @@
except:
pass

logger.info(f"Model will be loaded on device `{device}`")


PROMPT_DICT = {
"prompt_input": (
@@ -96,40 +107,68 @@ def main(
def instruct(
instruction,
input=None,
streaming=True,
temperature=0.1,
no_repeat_ngram_size=3,
max_new_tokens=512,
**kwargs,
):
prompt = generate_prompt(instruction, input)
tokenized_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = tokenized_inputs["input_ids"].to(device)
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

generation_config = GenerationConfig(
temperature=temperature,
no_repeat_ngram_size=no_repeat_ngram_size,
**kwargs,
)
with torch.inference_mode():

if streaming:
# Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
# in the main thread. A timeout is added to the streamer to handle exceptions in the generation thread.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
generation_config=generation_config,
# return_dict_in_generate=True,
# output_scores=True,
max_new_tokens=max_new_tokens,
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()

# Pull the generated text from the streamer, and update the model output.
output_text = ""
for new_text in streamer:
output_text += new_text
yield output_text
logger.info(output_text)
return output_text

else:
generation_output = model.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
# return_dict_in_generate=True,
# output_scores=True,
max_new_tokens=max_new_tokens,
)
s = generation_output.sequences[0]
output = tokenizer.decode(s, skip_special_tokens=True)
return output.split("### Réponse:")[1].strip()

output_text = tokenizer.decode(generation_output[0], skip_special_tokens=True)
logger.info(output_text)
output_text = output_text.rsplit("### Réponse:", 1)[-1].strip()
# Because this function contains `yield`, it is a generator; yield here as well so Gradio displays the non-streaming result
yield output_text

gr.Interface(
fn=instruct,
inputs=[
gr.inputs.Textbox(label="Instruction", default="Parlez-moi des alpacas."),
gr.inputs.Textbox(label="Instruction", default="Parlez-moi des vigognes."),
gr.inputs.Textbox(label="Input"),
gr.Checkbox(label="Streaming mode?", value=True),
],
outputs=[gr.outputs.Textbox(label="Output")],
outputs=[gr.Textbox(label="Output", interactive=False)],
title="🦙 Vigogne-LoRA",
description="Vigogne-LoRA is a 7B-parameter LLaMA model finetuned to follow the instructions in French. For more information, please visit [the project's website](https://github.com/bofenghuang/vigogne).",
description="Vigogne-LoRA is a 7B-parameter LLaMA model finetuned to follow the French 🇫🇷 instructions. For more information, please visit the [Github repo](https://github.com/bofenghuang/vigogne).",
).launch(enable_queue=True, share=True)
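For readers unfamiliar with the streaming approach this commit introduces, below is a minimal, self-contained sketch of the `TextIteratorStreamer` plus background-thread pattern used in `demo.py`. It is not the actual demo: `"gpt2"` is a placeholder checkpoint, and the prompt and generation settings are illustrative.

```python
# Minimal sketch of threaded streaming generation with transformers' TextIteratorStreamer.
# Assumption: "gpt2" is a stand-in model; prompt and parameters are illustrative only.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "gpt2"  # placeholder; the demo loads a LLaMA base model plus a Vigogne LoRA adapter
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer("Parlez-moi des vigognes.", return_tensors="pt")["input_ids"]

# skip_prompt drops the prompt tokens from the stream; the timeout surfaces
# exceptions raised in the generation thread instead of blocking forever.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

# model.generate blocks until it finishes, so it runs on a worker thread while
# the main thread consumes decoded text chunks from the streamer as they arrive.
thread = Thread(target=model.generate, kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=64))
thread.start()

output_text = ""
for new_text in streamer:
    output_text += new_text
    print(new_text, end="", flush=True)
thread.join()
```

In the demo itself, this accumulation loop lives in a generator that yields the running `output_text`, which Gradio (launched with `enable_queue=True`) renders incrementally in the output textbox.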


