diff --git a/examples/huggingface/Copy_of_DreamBooth_Stable_Diffusion.ipynb b/examples/huggingface/Copy_of_DreamBooth_Stable_Diffusion.ipynb new file mode 100644 index 00000000..63a87bb1 --- /dev/null +++ b/examples/huggingface/Copy_of_DreamBooth_Stable_Diffusion.ipynb @@ -0,0 +1,1287 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N9q4Kfqyg8WX" + }, + "source": [ + "# Step 0 - Connect to a virtual machine and Google Drive" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "XU7NuMAA2drw", + "outputId": "e3955d95-767b-431f-a874-90c385a2cb7d", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Tesla T4, 15360 MiB, 15101 MiB\n", + "Mounted at /content/gdrive/\n" + ] + } + ], + "source": [ + "#@markdown This is a code block. Click the run icon on the left to check type of GPU and VRAM available for you!\n", + "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader\n", + "\n", + "from google.colab import drive\n", + "\n", + "drive.mount('/content/gdrive/', force_remount=True)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wnTMyW41cC1E" + }, + "source": [ + "# Step 1 - Install Requirements\n", + "This block will install all the Python dependencies in your environment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aLWXPZqjsZVV", + "outputId": "11567a4a-9495-46c3-ea54-9f661df780dd", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for diffusers (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.3/63.3 MB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m191.5/191.5 kB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m33.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.5/62.5 MB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.0/20.0 MB\u001b[0m \u001b[31m55.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m106.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m77.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m65.7/65.7 kB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m297.4/297.4 kB\u001b[0m \u001b[31m24.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.4/75.4 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.5/50.5 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.3/140.3 kB\u001b[0m \u001b[31m15.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.7/45.7 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.5/59.5 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.9/129.9 kB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.4/50.4 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.5/46.5 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.7/43.7 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.0/41.0 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m 
\u001b[32m41.0/41.0 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.5/87.5 kB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.5/84.5 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.0/67.0 kB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m74.5/74.5 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for ffmpy (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "!wget -q https://raw.githubusercontent.com/buildspace/diffusers/main/examples/dreambooth/train_dreambooth.py\n", + "!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n", + "%pip install -qq git+https://github.com/ShivamShrirao/diffusers\n", + "%pip install -q -U --pre triton\n", + "%pip install -q accelerate==0.15.0 transformers ftfy bitsandbytes==0.35.0 gradio natsort" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w1Nn3CL4fUjp" + }, + "source": [ + "# Step 2 - Login to HuggingFace 🤗" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "y4lqqWT_uxD2" + }, + "outputs": [], + "source": [ + "#@markdown We're gonna be loading the Stable Diffusion model from HuggingFace. 
Head over to [HuggingFace](huggingface.co?ref=buildspace) and sign up for a free account.\n", + "\n", + "#@markdown Once you're done, head over to the [tokens page](https://huggingface.co/settings/tokens) in the settings and create a read only token\n", + "#@markdown We're going to be using Stable Diffusion v1.5 from Runway, so make sure you check out the license in the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5).\n", + "\n", + "!mkdir -p ~/.huggingface\n", + "HUGGINGFACE_TOKEN = \"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\" #@param {type:\"string\"}\n", + "!echo -n \"{HUGGINGFACE_TOKEN}\" > ~/.huggingface/token" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XfTlc8Mqb8iH" + }, + "source": [ + "# Step 3 - Install xformers from precompiled wheel.\n", + "xformers are another dependency that we'll need - these are used for language processing and text classification.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n6dcjPnnaiCn", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "3f61e778-b267-47bd-afe1-e60a2da6ed16" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[31mERROR: xformers-0.0.15.dev0+4c06c79.d20221205-cp38-cp38-linux_x86_64.whl is not a supported wheel on this platform.\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "%pip install -q https://github.com/brian6091/xformers-wheels/releases/download/0.0.15.dev0%2B4c06c79/xformers-0.0.15.dev0+4c06c79.d20221205-cp38-cp38-linux_x86_64.whl\n", + "# These were compiled on Tesla T4.\n", + "\n", + "# If precompiled wheels don't work, install it with the following command. It will take around 40 minutes to compile.\n", + "# %pip install git+https://github.com/facebookresearch/xformers@4c06c79#egg=xformers" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G0NV324ZcL9L" + }, + "source": [ + "# Step 4 - Configure your model\n", + "As mentioned, we're going to be using HuggingFace to load the Stable Diffusion model. There's lots of differently tuned SD models on HF, we're going to stick with the standard v1.5 released by Runway.\n", + "\n", + "The way you choose is a model is by putting in the path of the URL on HuggingFace. So `https://huggingface.co/runwayml/stable-diffusion-v1-5` becomes `runwayml/stable-diffusion-v1-5`.\n", + "\n", + "You're welcome to try other versions, but we've only tested this on v1.5 and v2.1!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "Rxg0y5MBudmd", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "c9983bfe-ebf2-4e37-c701-01e1b295a095" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n", + "[*] Weights will be saved at /content/drive/MyDrive/stable_diffusion_weights/sd_v1_5\n" + ] + } + ], + "source": [ + "#@markdown Check this if you want to save the weights in your Google Drive (takes around 4-5 GB).\n", + "#@markdown Definitely do this or your model will poof when you disconnect the Colab.\n", + "#@markdown Once we upload it to HuggingFace, it'll be saved in two places - Gdrive + HuggingFace\n", + "\n", + "save_to_gdrive = True #@param {type:\"boolean\"}\n", + "if save_to_gdrive:\n", + " from google.colab import drive\n", + " drive.mount('/content/drive')\n", + "\n", + "#@markdown Name/path of the initial model.\n", + "MODEL_NAME = \"runwayml/stable-diffusion-v1-5\" #@param {type:\"string\"}\n", + "\n", + "#@markdown Enter the Gdrive directory to save model at.\n", + "\n", + "OUTPUT_DIR = \"stable_diffusion_weights/sd_v1_5\" #@param {type:\"string\"}\n", + "if save_to_gdrive:\n", + " OUTPUT_DIR = \"/content/drive/MyDrive/\" + OUTPUT_DIR\n", + "else:\n", + " OUTPUT_DIR = \"/content/\" + OUTPUT_DIR\n", + "\n", + "print(f\"[*] Weights will be saved at {OUTPUT_DIR}\")\n", + "\n", + "!mkdir -p $OUTPUT_DIR" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qn5ILIyDJIcX" + }, + "source": [ + "# Step 5 - Configure the training resources\n", + "We won't need to touch any of this, but it's here if you want to come back and try turning the knobs once you understand this stuff!\n", + "\n", + "Use the table below to choose the best flags based on your memory and speed requirements. Tested on Tesla T4 GPU.\n", + "\n", + "\n", + "| `fp16` | `train_batch_size` | `gradient_accumulation_steps` | `gradient_checkpointing` | `use_8bit_adam` | GB VRAM usage | Speed (it/s) |\n", + "| ---- | ------------------ | ----------------------------- | ----------------------- | --------------- | ---------- | ------------ |\n", + "| fp16 | 1 | 1 | TRUE | TRUE | 9.92 | 0.93 |\n", + "| no | 1 | 1 | TRUE | TRUE | 10.08 | 0.42 |\n", + "| fp16 | 2 | 1 | TRUE | TRUE | 10.4 | 0.66 |\n", + "| fp16 | 1 | 1 | FALSE | TRUE | 11.17 | 1.14 |\n", + "| no | 1 | 1 | FALSE | TRUE | 11.17 | 0.49 |\n", + "| fp16 | 1 | 2 | TRUE | TRUE | 11.56 | 1 |\n", + "| fp16 | 2 | 1 | FALSE | TRUE | 13.67 | 0.82 |\n", + "| fp16 | 1 | 2 | FALSE | TRUE | 13.7 | 0.83 |\n", + "| fp16 | 1 | 1 | TRUE | FALSE | 15.79 | 0.77 |\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-ioxxvHoicPs" + }, + "source": [ + "Add `--gradient_checkpointing` flag for around 9.92 GB VRAM usage.\n", + "\n", + "remove `--use_8bit_adam` flag for full precision. Requires 15.79 GB with `--gradient_checkpointing` else 17.8 GB.\n", + "\n", + "remove `--train_text_encoder` flag to reduce memory usage further, degrades output quality." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g9USRfYHkCeL" + }, + "source": [ + "## Step 5.5 - Tell Stable Diffusion what you're turning for\n", + "Here's where you tell Stable diffusion *what* you're tuning for.\n", + "\n", + "**Instance prompt**: this describes exactly what your images are of. 
In our case it's whatever we decided as the name (\"abraza\" for me) and \"man/woman/person\". This is the **label** for the images we uploaded.\n", + "\n", + "**Class prompt**: this just describes what else Stable Diffusion should relate your model to. \"man\", \"woman\" or \"person\" works :)\n", + "\n", + "All you need to do is put your unique identifier (\"abraza\") here and whatever the class is right here. Make sure you run both blocks!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TxueducWGeij" + }, + "outputs": [], + "source": [ + "INSTANCE_NAME = 'tree' #@param {type:\"string\"}\n", + "CLASS_NAME = 'living organism' #@param {type:\"string\"}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5vDpCxId1aCm" + }, + "outputs": [], + "source": [ + "# You can also add multiple concepts here. Try tweaking `--max_train_steps` accordingly.\n", + "\n", + "concepts_list = [\n", + " {\n", + " \"instance_prompt\": f\"photo of {INSTANCE_NAME} {CLASS_NAME}\",\n", + " \"class_prompt\": f\"photo of a {CLASS_NAME}\",\n", + " \"instance_data_dir\": f\"/content/data/{INSTANCE_NAME}\",\n", + " \"class_data_dir\": f\"/content/data/{CLASS_NAME}\"\n", + " },\n", + "# {\n", + "# \"instance_prompt\": \"photo of ukj person\",\n", + "# \"class_prompt\": \"photo of a person\",\n", + "# \"instance_data_dir\": \"/content/data/ukj\",\n", + "# \"class_data_dir\": \"/content/data/person\"\n", + "# }\n", + "]\n", + "\n", + "# `class_data_dir` contains regularization images\n", + "import json\n", + "import os\n", + "for c in concepts_list:\n", + " os.makedirs(c[\"instance_data_dir\"], exist_ok=True)\n", + "\n", + "with open(\"concepts_list.json\", \"w\") as f:\n", + " json.dump(concepts_list, f, indent=4)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RRXKAhWrk0Qj" + }, + "source": [ + "# Step 6 - Upload your images" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "32gYIDDR1aCp", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 90 + }, + "outputId": "f1ff9d3a-9a7c-4917-e724-eca9ca2153ad" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Uploading instance images for `photo of tree living organism`\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " \n", + " Upload widget is only available when the cell has been executed in the\n", + " current browser session. Please rerun this cell to enable.\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Saving tree.png to tree.png\n" + ] + } + ], + "source": [ + "#@markdown Upload your images by running this cell (recommended).\n", + "\n", + "#@markdown Run this block and the \"choose files\" button will pop up. 
Remember - no more than 10 pictures!\n", + "\n", + "#@markdown OR\n", + "\n", + "#@markdown You can use the file manager on the left panel to upload (drag and drop) to each `instance_data_dir` (it uploads faster)\n", + "\n", + "import os\n", + "from google.colab import files\n", + "import shutil\n", + "\n", + "for c in concepts_list:\n", + " print(f\"Uploading instance images for `{c['instance_prompt']}`\")\n", + " uploaded = files.upload()\n", + " for filename in uploaded.keys():\n", + " dst_path = os.path.join(c['instance_data_dir'], filename)\n", + " shutil.move(filename, dst_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HRCp7DxllTVA" + }, + "source": [ + "# Step 7 - Configure the training options and TRAIN!\n", + "Okay this may seem intimidating, but you don't have to touch most of it!\n", + "\n", + "Again, I've left these in here if you really know what you're doing and want to customise your model, for your first time all you need to do is:\n", + "1. Change `max_train_steps`. You wanna keep this number lower than 2000 - the higher it goes, the longer training takes and the more \"familiar\" SD becomes with you. Keep this number small to avoid overfitting. The general rule of thumb here is 100 steps for each picture+ a base of100. So for 6 pictures, just set it to 700!\n", + "2. **Update `save_sample_prompt` to a prompt with your subject.** Right after training, this block will generate 4 images of you with this prompt. I recommend spazzing it up a bit more than just \"Photo of xyz person\", those come out quite boring. Put those prompting skills to use!\n", + "\n", + "### This will take ~20m to run!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jjcSXTp-u-Eg", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "8090d189-ea7f-42b3-965a-191f5f6b0ae6" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Traceback (most recent call last):\n", + " File \"/usr/local/bin/accelerate\", line 5, in \n", + " from accelerate.commands.accelerate_cli import main\n", + " File \"/usr/local/lib/python3.10/dist-packages/accelerate/__init__.py\", line 7, in \n", + " from .accelerator import Accelerator\n", + " File \"/usr/local/lib/python3.10/dist-packages/accelerate/accelerator.py\", line 27, in \n", + " from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state\n", + " File \"/usr/local/lib/python3.10/dist-packages/accelerate/checkpointing.py\", line 24, in \n", + " from .utils import (\n", + " File \"/usr/local/lib/python3.10/dist-packages/accelerate/utils/__init__.py\", line 103, in \n", + " from .megatron_lm import (\n", + " File \"/usr/local/lib/python3.10/dist-packages/accelerate/utils/megatron_lm.py\", line 32, in \n", + " from transformers.modeling_outputs import (\n", + " File \"/usr/local/lib/python3.10/dist-packages/transformers/__init__.py\", line 26, in \n", + " from . 
import dependency_versions_check\n", + " File \"/usr/local/lib/python3.10/dist-packages/transformers/dependency_versions_check.py\", line 57, in \n", + " require_version_core(deps[pkg])\n", + " File \"/usr/local/lib/python3.10/dist-packages/transformers/utils/versions.py\", line 117, in require_version_core\n", + " return require_version(requirement, hint)\n", + " File \"/usr/local/lib/python3.10/dist-packages/transformers/utils/versions.py\", line 111, in require_version\n", + " _compare_versions(op, got_ver, want_ver, requirement, pkg, hint)\n", + " File \"/usr/local/lib/python3.10/dist-packages/transformers/utils/versions.py\", line 44, in _compare_versions\n", + " raise ImportError(\n", + "ImportError: accelerate>=0.20.3 is required for a normal functioning of this module, but found accelerate==0.15.0.\n", + "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git main\n" + ] + } + ], + "source": [ + "!accelerate launch train_dreambooth.py \\\n", + " --pretrained_model_name_or_path=$MODEL_NAME \\\n", + " --pretrained_vae_name_or_path=\"stabilityai/sd-vae-ft-mse\" \\\n", + " --output_dir=$OUTPUT_DIR \\\n", + " --revision=\"fp16\" \\\n", + " --with_prior_preservation --prior_loss_weight=1.0 \\\n", + " --seed=1337 \\\n", + " --resolution=512 \\\n", + " --train_batch_size=1 \\\n", + " --train_text_encoder \\\n", + " --mixed_precision=\"fp16\" \\\n", + " --use_8bit_adam \\\n", + " --gradient_accumulation_steps=1 \\\n", + " --learning_rate=1e-6 \\\n", + " --lr_scheduler=\"constant\" \\\n", + " --lr_warmup_steps=0 \\\n", + " --num_class_images=50 \\\n", + " --sample_batch_size=4 \\\n", + " --max_train_steps=500 \\\n", + " --save_interval=10000 \\\n", + " --save_sample_prompt=\"Photo of {INSTANCE_NAME} {CLASS_NAME}, highly detailed, 8k, uhd, studio lighting, beautiful\" \\\n", + " --concepts_list=\"concepts_list.json\"\n", + "\n", + "# Reduce the `--save_interval` to lower than `--max_train_steps` to save weights from intermediate steps.\n", + "# `--save_sample_prompt` can be same as `--instance_prompt` to generate intermediate samples (saved along with weights in samples directory)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "89Az5NUxOWdy", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b4907a75-8b67-40d4-bb41-838501a0f0d4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[*] WEIGHTS_DIR=/content/drive/MyDrive/stable_diffusion_weights/sd_v1_5/500\n" + ] + } + ], + "source": [ + "#@markdown ## Step 7.2 - Set weights (run without changes first time)\n", + "#@markdown Specify which tuned weights you want to use - you only need to change this if you're generating with an existing tuned model. 
**Leave it blank and it'll use the latest weights (SD v1.5).**\n", + "\n", + "#@markdown This is a Google Drive path, so if you wanna change it, it should look something like\n", + "#@markdown `/content/drive/MyDrive/stable_diffusion_weights/raza/FOLDER_WITH_CKPT_FILE`\n", + "WEIGHTS_DIR = \"\" #@param {type:\"string\"}\n", + "if WEIGHTS_DIR == \"\":\n", + " from natsort import natsorted\n", + " from glob import glob\n", + " import os\n", + " WEIGHTS_DIR = natsorted(glob(OUTPUT_DIR + os.sep + \"*\"))[-1]\n", + "print(f\"[*] WEIGHTS_DIR={WEIGHTS_DIR}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "HMy4p4aNnG--", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 407 + }, + "outputId": "866b2e97-80f4-411e-a511-5597d83271fd" + }, + "outputs": [ + { + "output_type": "error", + "ename": "ValueError", + "evalue": "ignored", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mcol\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights_folder\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfolders\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"samples\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mscale\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msubplots\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfigsize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcol\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mscale\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrow\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mscale\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgridspec_kw\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m'hspace'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'wspace'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfolder\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfolders\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/matplotlib/pyplot.py\u001b[0m in \u001b[0;36msubplots\u001b[0;34m(nrows, ncols, sharex, sharey, squeeze, width_ratios, height_ratios, subplot_kw, gridspec_kw, **fig_kw)\u001b[0m\n\u001b[1;32m 1500\u001b[0m \"\"\"\n\u001b[1;32m 1501\u001b[0m \u001b[0mfig\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mfigure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mfig_kw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1502\u001b[0;31m axs = fig.subplots(nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey,\n\u001b[0m\u001b[1;32m 1503\u001b[0m \u001b[0msqueeze\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msubplot_kw\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msubplot_kw\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1504\u001b[0m \u001b[0mgridspec_kw\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgridspec_kw\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheight_ratios\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mheight_ratios\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/matplotlib/figure.py\u001b[0m in \u001b[0;36msubplots\u001b[0;34m(self, nrows, ncols, sharex, sharey, squeeze, width_ratios, height_ratios, subplot_kw, gridspec_kw)\u001b[0m\n\u001b[1;32m 903\u001b[0m \u001b[0mgridspec_kw\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'width_ratios'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwidth_ratios\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 904\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 905\u001b[0;31m \u001b[0mgs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_gridspec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnrows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mncols\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfigure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mgridspec_kw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 906\u001b[0m axs = gs.subplots(sharex=sharex, sharey=sharey, squeeze=squeeze,\n\u001b[1;32m 907\u001b[0m subplot_kw=subplot_kw)\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/matplotlib/figure.py\u001b[0m in \u001b[0;36madd_gridspec\u001b[0;34m(self, nrows, ncols, **kwargs)\u001b[0m\n\u001b[1;32m 1525\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1526\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'figure'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# pop in case user has added this...\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1527\u001b[0;31m \u001b[0mgs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGridSpec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnrows\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnrows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mncols\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mncols\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfigure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1528\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1529\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/matplotlib/gridspec.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, nrows, ncols, figure, left, bottom, right, top, wspace, hspace, 
width_ratios, height_ratios)\u001b[0m\n\u001b[1;32m 377\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfigure\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 378\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 379\u001b[0;31m super().__init__(nrows, ncols,\n\u001b[0m\u001b[1;32m 380\u001b[0m \u001b[0mwidth_ratios\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwidth_ratios\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 381\u001b[0m height_ratios=height_ratios)\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/matplotlib/gridspec.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, nrows, ncols, height_ratios, width_ratios)\u001b[0m\n\u001b[1;32m 50\u001b[0m f\"Number of rows must be a positive integer, not {nrows!r}\")\n\u001b[1;32m 51\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mncols\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIntegral\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mncols\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 52\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 53\u001b[0m f\"Number of columns must be a positive integer, not {ncols!r}\")\n\u001b[1;32m 54\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nrows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ncols\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnrows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mncols\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Number of columns must be a positive integer, not 0" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ] + }, + "metadata": {} + } + ], + "source": [ + "#@markdown ## Step 7.3 - Generate test images!\n", + "#@markdown Run to generate a grid of preview images from the last saved weights.\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", + "\n", + "weights_folder = OUTPUT_DIR\n", + "folders = sorted([f for f in os.listdir(weights_folder) if f != \"0\"], key=lambda x: int(x))\n", + "\n", + "row = len(folders)\n", + "col = len(os.listdir(os.path.join(weights_folder, folders[0], \"sampl\n", + "\n", + "for i, folder in enumerate(folders):\n", + " folder_path = os.path.join(weights_folder, folder)\n", + " image_folder = os.path.join(folder_path, \"samples\")\n", + " images = [f for f in os.listdir(image_folder)]\n", + " for j, image in enumerate(images):\n", + " if row == 1:\n", + " currAxes = axes[j]\n", + " else:\n", + " currAxes = axes[i, j]\n", + " if i == 0:\n", + " currAxes.set_title(f\"Image {j}\")\n", + " if j == 0:\n", + " currAxes.text(-0.1, 0.5, folder, rotation=0, va='center', ha='center', transform=currAxes.transAxes)\n", + " image_path = os.path.join(image_folder, image)\n", + " img = mpimg.imread(image_path)\n", + " currAxes.imshow(img, cmap='gray')\n", + " currAxes.axis('off')\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig('grid.png', dpi=72)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5V8wgU0HN-Kq" + }, + "source": [ + "# Step 8 - Convert weights to CKPT\n", + "Since we want to use our fancy new tuned SD with web apps, we'll have to convert it to CKPT.\n", + "\n", + "If you want to save the CKPT file to your GDrive for the future and are running out of space, you can convert it to fp16, which halves the size but also severely degrades the quality. I recommend **leaving it unchecked** cause we're just going to upload it to HuggingFace at the end." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "dcXzsUyG1aCy" + }, + "outputs": [], + "source": [ + "#@markdown Run this block to start the conversion (necessary)\n", + "ckpt_path = WEIGHTS_DIR + \"/model.ckpt\"\n", + "#@markdown ----\n", + "#@markdown Check this box to convert to fp16, takes half the space (2GB). Not necessary and not recommmended.\n", + "half_arg = \"\"\n", + "fp16 = False #@param {type: \"boolean\"}\n", + "\n", + "if fp16:\n", + " half_arg = \"--half\"\n", + "!python convert_diffusers_to_original_stable_diffusion.py --model_path $WEIGHTS_DIR --checkpoint_path $ckpt_path $half_arg\n", + "print(f\"[*] Converted ckpt saved at {ckpt_path}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ToNG4fd_dTbF" + }, + "source": [ + "# Step 9 - Inference\n", + "Alrighty! We're ready to get cranking! 
This block will prepare the newly trained and converted model for the textual prompts used for image generation.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gW15FjffdTID" + }, + "outputs": [], + "source": [ + "import torch\n", + "from torch import autocast\n", + "from diffusers import StableDiffusionPipeline, DDIMScheduler\n", + "from IPython.display import display\n", + "\n", + "model_path = WEIGHTS_DIR # If you want to use previously trained model saved in gdrive, replace this with the full path of model in gdrive\n", + "\n", + "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n", + "pipe = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16).to(\"cuda\")\n", + "\n", + "g_cuda = None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "oIzkltjpVO_f" + }, + "outputs": [], + "source": [ + "#@markdown Can set random seed here for reproducibility.\n", + "g_cuda = torch.Generator(device='cuda')\n", + "seed = 52362 #@param {type:\"number\"}\n", + "g_cuda.manual_seed(seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nM-L5Dl8vcly" + }, + "source": [ + "# Step 10 - Generate images!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "K6xoHWSsbcS3", + "scrolled": false + }, + "outputs": [], + "source": [ + "#@markdown Add your name in the prompt, configure steps and scale, and make magic happen!\n", + "\n", + "prompt = \"YOUR_NAME_HERE intricate character portrait, intricate, beautiful, 8k resolution, dynamic lighting, hyperdetailed, quality 3D rendered, volumetric lighting, greg rutkowski, detailed background, artstation character portrait, dnd character portrait\" #@param {type:\"string\"}\n", + "negative_prompt = \"duplication, smile\" #@param {type:\"string\"}\n", + "num_samples = 2 #@param {type:\"number\"}\n", + "guidance_scale = 7.5 #@param {type:\"number\"}\n", + "num_inference_steps = 50 #@param {type:\"number\"}\n", + "height = 512 #@param {type:\"number\"}\n", + "width = 512 #@param {type:\"number\"}\n", + "\n", + "with autocast(\"cuda\"), torch.inference_mode():\n", + " images = pipe(\n", + " prompt,\n", + " height=height,\n", + " width=width,\n", + " negative_prompt=negative_prompt,\n", + " num_images_per_prompt=num_samples,\n", + " num_inference_steps=num_inference_steps,\n", + " guidance_scale=guidance_scale,\n", + " generator=g_cuda\n", + " ).images\n", + "\n", + "for img in images:\n", + " display(img)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "WMCqQ5Tcdsm2" + }, + "outputs": [], + "source": [ + "#@markdown Optional - Run Gradio UI for generating images. 
This will set up a fancy web UI that you can use for prompting instead of the ugly colab inputs.\n", + "import gradio as gr\n", + "\n", + "def inference(prompt, negative_prompt, num_samples, height=512, width=512, num_inference_steps=50, guidance_scale=7.5):\n", + " with torch.autocast(\"cuda\"), torch.inference_mode():\n", + " return pipe(\n", + " prompt, height=int(height), width=int(width),\n", + " negative_prompt=negative_prompt,\n", + " num_images_per_prompt=int(num_samples),\n", + " num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,\n", + " generator=g_cuda\n", + " ).images\n", + "\n", + "with gr.Blocks() as demo:\n", + " with gr.Row():\n", + " with gr.Column():\n", + " prompt = gr.Textbox(label=\"Prompt\", value=\"photo of {INSTANCE_NAME} {CLASS_NAME} in a bucket\")\n", + " negative_prompt = gr.Textbox(label=\"Negative Prompt\", value=\"\")\n", + " run = gr.Button(value=\"Generate\")\n", + " with gr.Row():\n", + " num_samples = gr.Number(label=\"Number of Samples\", value=4)\n", + " guidance_scale = gr.Number(label=\"Guidance Scale\", value=7.5)\n", + " with gr.Row():\n", + " height = gr.Number(label=\"Height\", value=512)\n", + " width = gr.Number(label=\"Width\", value=512)\n", + " num_inference_steps = gr.Slider(label=\"Steps\", value=50)\n", + " with gr.Column():\n", + " gallery = gr.Gallery()\n", + "\n", + " run.click(inference, inputs=[prompt, negative_prompt, num_samples, height, width, num_inference_steps, guidance_scale], outputs=gallery)\n", + "\n", + "demo.launch(debug=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fkM9V-tYxMv7" + }, + "source": [ + "# Step 11 - Upload your custom trained model to HuggingFace\n", + "The final routine of this magic trick - putting your model on HuggingFace so you get handy inference endpoints.\n", + "\n", + "You can do this by downloading all your files from Google Drive and manually creating a HuggingFace project but I wanted to save you the time so this block here will do it all for you :D" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "PYQsyd-SI85R" + }, + "outputs": [], + "source": [ + "from slugify import slugify\n", + "from huggingface_hub import HfApi, HfFolder, CommitOperationAdd\n", + "from huggingface_hub import create_repo\n", + "from IPython.display import display_markdown\n", + "from IPython.display import clear_output\n", + "from IPython.utils import capture\n", + "from google.colab import files\n", + "import shutil\n", + "import time\n", + "import os\n", + "\n", + "Upload_sample_images = False #@param {type:\"boolean\"}\n", + "#@markdown - Upload showcase images of your trained model\n", + "\n", + "Name_of_your_concept = \"YOUR CONCEPT NAME\" #@param {type:\"string\"}\n", + "if(Name_of_your_concept == \"\"):\n", + " # should update this to be like the SESSION_NAME from the other notebook.\n", + " # Some sort of tracable name throughout\n", + " Name_of_your_concept = INSTANCE_NAME\n", + "Name_of_your_concept=Name_of_your_concept.replace(\" \",\"-\")\n", + "\n", + "Save_concept_to = \"My_Profile\" #@param [\"Public_Library\", \"My_Profile\"]\n", + "\n", + "#@markdown - [Create a write access token](https://huggingface.co/settings/tokens) , go to \"New token\" -> Role : Write. 
A regular read token won't work here.\n", + "hf_token_write = \"hf_xxxxxxxxxxxx\" #@param {type:\"string\"}\n", + "if hf_token_write ==\"\":\n", + " print('\u001b[1;32mYour Hugging Face write access token : ')\n", + " hf_token_write=input()\n", + "\n", + "hf_token = hf_token_write\n", + "\n", + "api = HfApi()\n", + "your_username = api.whoami(token=hf_token)[\"name\"]\n", + "\n", + "if(Save_concept_to == \"Public_Library\"):\n", + " repo_id = f\"sd-dreambooth-library/{slugify(Name_of_your_concept)}\"\n", + " #Join the Concepts Library organization if you aren't part of it already\n", + " !curl -X POST -H 'Authorization: Bearer '$hf_token -H 'Content-Type: application/json' https://huggingface.co/organizations/sd-dreambooth-library/share/SSeOwppVCscfTEzFGQaqpfcjukVeNrKNHX\n", + "else:\n", + " repo_id = f\"{your_username}/{slugify(Name_of_your_concept)}\"\n", + "MDLPTH=str(WEIGHTS_DIR+\"/model.ckpt\")\n", + "\n", + "def bar(prg):\n", + " br=\"\u001b[1;33mUploading to HuggingFace : \" '\u001b[0m|'+'█' * prg + ' ' * (25-prg)+'| ' +str(prg*4)+ \"%\"\n", + " return br\n", + "\n", + "print(\"\u001b[1;32mLoading...\")\n", + "\n", + "NM=\"False\"\n", + "if os.path.getsize(WEIGHTS_DIR+\"/text_encoder/pytorch_model.bin\") > 670901463:\n", + " NM=\"True\"\n", + "\n", + "\n", + "if NM==\"False\":\n", + " with capture.capture_output() as cap:\n", + " %cd $WEIGHTS_DIR\n", + " !rm -r safety_checker feature_extractor .git\n", + " !rm model_index.json\n", + " !git init\n", + " !git lfs install --system --skip-repo\n", + " !git remote add -f origin \"https://USER:{hf_token}@huggingface.co/runwayml/stable-diffusion-v1-5\"\n", + " !git config core.sparsecheckout true\n", + " !echo -e \"feature_extractor\\nsafety_checker\\nmodel_index.json\" > .git/info/sparse-checkout\n", + " !git pull origin main\n", + " !rm -r .git\n", + " %cd /content\n", + "\n", + "image_string = \"\"\n", + "\n", + "if os.path.exists('/content/sample_images'):\n", + " !rm -r /content/sample_images\n", + "Samples=\"/content/sample_images\"\n", + "!mkdir $Samples\n", + "clear_output()\n", + "\n", + "if Upload_sample_images:\n", + "\n", + " print(\"\u001b[1;32mUpload Sample images of the model\")\n", + " uploaded = files.upload()\n", + " for filename in uploaded.keys():\n", + " shutil.move(filename, Samples)\n", + " %cd $Samples\n", + " !find . 
-name \"* *\" -type f | rename 's/ /_/g'\n", + " %cd /content\n", + " clear_output()\n", + "\n", + " print(bar(1))\n", + "\n", + " images_upload = os.listdir(Samples)\n", + " instance_prompt_list = []\n", + " for i, image in enumerate(images_upload):\n", + " image_string = f'''\n", + " {image_string}![{i}](https://huggingface.co/{repo_id}/resolve/main/sample_images/{image})\n", + " '''\n", + "\n", + "readme_text = f'''---\n", + "license: creativeml-openrail-m\n", + "tags:\n", + "- text-to-image\n", + "- stable-diffusion\n", + "---\n", + "### {Name_of_your_concept} Dreambooth model trained by {api.whoami(token=hf_token)[\"name\"]} with [buildspace's DreamBooth](https://colab.research.google.com/github/buildspace/diffusers/blob/main/examples/dreambooth/DreamBooth_Stable_Diffusion.ipynb) notebook\n", + "\n", + "Build your own using the [AI Avatar project](https://buildspace.so/builds/ai-avatar)!\n", + "\n", + "To get started head over to the [project dashboard](https://buildspace.so/p/build-ai-avatars).\n", + "\n", + "Sample pictures of this concept:\n", + "{image_string}\n", + "'''\n", + "#Save the readme to a file\n", + "readme_file = open(\"README.md\", \"w\")\n", + "readme_file.write(readme_text)\n", + "readme_file.close()\n", + "\n", + "operations = [\n", + " CommitOperationAdd(path_in_repo=\"README.md\", path_or_fileobj=\"README.md\"),\n", + " CommitOperationAdd(path_in_repo=f\"model.ckpt\",path_or_fileobj=MDLPTH)\n", + "\n", + "]\n", + "create_repo(repo_id,private=True, token=hf_token)\n", + "\n", + "api.create_commit(\n", + " repo_id=repo_id,\n", + " operations=operations,\n", + " commit_message=f\"Upload the concept {Name_of_your_concept} embeds and token\",\n", + " token=hf_token\n", + ")\n", + "\n", + "if NM==\"False\":\n", + " api.upload_folder(\n", + " folder_path=WEIGHTS_DIR+\"/feature_extractor\",\n", + " path_in_repo=\"feature_extractor\",\n", + " repo_id=repo_id,\n", + " token=hf_token\n", + " )\n", + "\n", + "clear_output()\n", + "print(bar(4))\n", + "\n", + "if NM==\"False\":\n", + " api.upload_folder(\n", + " folder_path=WEIGHTS_DIR+\"/safety_checker\",\n", + " path_in_repo=\"safety_checker\",\n", + " repo_id=repo_id,\n", + " token=hf_token\n", + " )\n", + "\n", + "clear_output()\n", + "print(bar(8))\n", + "\n", + "\n", + "api.upload_folder(\n", + " folder_path=WEIGHTS_DIR+\"/scheduler\",\n", + " path_in_repo=\"scheduler\",\n", + " repo_id=repo_id,\n", + " token=hf_token\n", + ")\n", + "\n", + "clear_output()\n", + "print(bar(9))\n", + "\n", + "api.upload_folder(\n", + " folder_path=WEIGHTS_DIR+\"/text_encoder\",\n", + " path_in_repo=\"text_encoder\",\n", + " repo_id=repo_id,\n", + " token=hf_token\n", + ")\n", + "\n", + "clear_output()\n", + "print(bar(12))\n", + "\n", + "api.upload_folder(\n", + " folder_path=WEIGHTS_DIR+\"/tokenizer\",\n", + " path_in_repo=\"tokenizer\",\n", + " repo_id=repo_id,\n", + " token=hf_token\n", + ")\n", + "\n", + "clear_output()\n", + "print(bar(13))\n", + "\n", + "api.upload_folder(\n", + " folder_path=WEIGHTS_DIR+\"/unet\",\n", + " path_in_repo=\"unet\",\n", + " repo_id=repo_id,\n", + " token=hf_token\n", + ")\n", + "\n", + "clear_output()\n", + "print(bar(21))\n", + "\n", + "api.upload_folder(\n", + " folder_path=WEIGHTS_DIR+\"/vae\",\n", + " path_in_repo=\"vae\",\n", + " repo_id=repo_id,\n", + " token=hf_token\n", + ")\n", + "\n", + "clear_output()\n", + "print(bar(23))\n", + "\n", + "api.upload_file(\n", + " path_or_fileobj=WEIGHTS_DIR+\"/model_index.json\",\n", + " path_in_repo=\"model_index.json\",\n", + " repo_id=repo_id,\n", + " 
token=hf_token\n", + ")\n", + "\n", + "clear_output()\n", + "print(bar(24))\n", + "\n", + "api.upload_folder(\n", + " folder_path=Samples,\n", + " path_in_repo=\"sample_images\",\n", + " repo_id=repo_id,\n", + " token=hf_token\n", + ")\n", + "\n", + "clear_output()\n", + "print(bar(25))\n", + "\n", + "display_markdown(f'''## Your concept was saved successfully. [Click here to access it](https://huggingface.co/{repo_id})\n", + "''', raw=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "lJoOgLQHnC8L" + }, + "outputs": [], + "source": [ + "#@title (Optional) Delete diffuser and old weights and only keep the ckpt to free up drive space.\n", + "\n", + "#@markdown [ ! ] Caution, Only execute if you are sure u want to delete the diffuser format weights and only use the ckpt.\n", + "import shutil\n", + "from glob import glob\n", + "import os\n", + "for f in glob(OUTPUT_DIR+os.sep+\"*\"):\n", + " if f != WEIGHTS_DIR:\n", + " shutil.rmtree(f)\n", + " print(\"Deleted\", f)\n", + "for f in glob(WEIGHTS_DIR+\"/*\"):\n", + " if not f.endswith(\".ckpt\") or not f.endswith(\".json\"):\n", + " try:\n", + " shutil.rmtree(f)\n", + " except NotADirectoryError:\n", + " continue\n", + " print(\"Deleted\", f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jXgi8HM4c-DA" + }, + "outputs": [], + "source": [ + "#@title Free runtime memory - this frees up memory and SHUTS DOWN THE ENVIRONMENT!\n", + "#@markdown Do not run this if you want to continue generating images here\n", + "exit()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [], + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "vscode": { + "interpreter": { + "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/huggingface/main.py b/examples/huggingface/main.py index 732f077f..390e4b7b 100644 --- a/examples/huggingface/main.py +++ b/examples/huggingface/main.py @@ -1,17 +1,21 @@ -import textbase + + import textbase from textbase.message import Message from textbase import models import os from typing import List +import transformers -#load your HuggingFace API key -models.HuggingFace.api_key = "" -# or from environment variable: +# Load your HuggingFace API key +models.HuggingFace.api_key = "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" +# or load from an environment variable: # models.HuggingFace.api_key = os.getenv("HUGGING_FACE_API_KEY") # Prompt for the model -SYSTEM_PROMPT = """you are and expert in large language model (llm) field and you will answer accordingly""" +SYSTEM_PROMPT = """you are an expert in the large language model (LLM) field and you will answer accordingly""" +# Specify the Hugging Face model name +model_name = "runwayml/stable-diffusion-v1-5" @textbase.chatbot("talking-bot") def on_message(message_history: List[Message], state: dict = None): @@ -27,11 +31,11 @@ def on_message(message_history: List[Message], state: dict = None): else: state["counter"] += 1 - #Generate Hugging face model response + # Generate Hugging Face model 
response bot_response = models.HuggingFace.generate( system_prompt=SYSTEM_PROMPT, message_history=message_history, - model="jasondubon/HubermanGPT-small-v1" + model=model_name ) return bot_response, state diff --git a/main.py b/main.py index a4a726a3..67e7a777 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,14 @@ from textbase import models import os from typing import List +import json +import requests + + +# Load your HuggingFace API key +models.HuggingFace.api_key = "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" +# or load from an environment variable: +# models.HuggingFace.api_key = os.getenv("HUGGING_FACE_API_KEY") # Load your OpenAI API key models.OpenAI.api_key = os.getenv("OPENAI_API_KEY") @@ -13,6 +21,8 @@ SYSTEM_PROMPT = """You are chatting with an AI. There are no specific prefixes for responses, so you can ask or talk about anything you like. The AI will respond in a natural, conversational manner. Feel free to start the conversation with any question or topic, and let's have a pleasant chat! """ +# Prompt for the model +SYSTEM_PROMPT = """you are an expert in the large language model (LLM) field and you will answer accordingly""" @textbase.chatbot("talking-bot") def on_message(message_history: List[Message], state: dict = None): @@ -28,11 +38,37 @@ def on_message(message_history: List[Message], state: dict = None): else: state["counter"] += 1 - # # Generate GPT-3.5 Turbo response - bot_response = models.OpenAI.generate( - system_prompt=SYSTEM_PROMPT, - message_history=message_history, - model="gpt-3.5-turbo", - ) + try: + assert models.HuggingFace.api_key is not None, "Hugging Face API key is not set" + + headers = {"Authorization": f"Bearer {models.HuggingFace.api_key}"} + API_URL = "https://api-inference.huggingface.co/models/gpt-3.5-turbo" + + inputs = { + "past_user_inputs": [SYSTEM_PROMPT], + "generated_responses": [], + "text": "" + } + + for message in message_history: + if message.role == "user": + inputs["past_user_inputs"].append(message.content) + else: + inputs["generated_responses"].append(message.content) + + inputs["text"] = inputs["past_user_inputs"].pop(-1) + payload = { + "inputs": inputs, + "max_length": 3000, + "temperature": 0.7, + } + data = json.dumps(payload) + response = requests.request("POST", API_URL, headers=headers, data=data) + response = json.loads(response.content.decode("utf-8")) + + bot_response = response["generated_text"] + return bot_response, state - return bot_response, state \ No newline at end of file + except Exception as ex: + error_message = f"Error: {ex}" + return error_message, state diff --git a/textbase/backend.py b/textbase/backend.py index 9c482a9d..392b07fd 100644 --- a/textbase/backend.py +++ b/textbase/backend.py @@ -1,4 +1,3 @@ -# textbase/backend.py from fastapi import FastAPI from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse @@ -15,8 +14,6 @@ load_dotenv() -from .message import Message - app = FastAPI() origins = [ @@ -34,29 +31,12 @@ allow_headers=["*"], ) - @app.get("/", response_class=HTMLResponse) async def read_root(): - """ - The `read_root` function reads and returns the contents of an HTML file specified by the path - "textbase/frontend/index.html". - :return: The content of the "index.html" file located in the "textbase/frontend" directory is being - returned. 
- """ with open("textbase/frontend/dist/index.html") as f: return f.read() - def get_module_from_file_path(file_path: str): - """ - The function `get_module_from_file_path` takes a file path as input, loads the module from the file, - and returns the module. - - :param file_path: The file path is the path to the Python file that you want to import as a module. - It should be a string representing the absolute or relative path to the file - :type file_path: str - :return: the module that is loaded from the given file path. - """ module_name = os.path.splitext(os.path.basename(file_path))[0] spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) @@ -64,25 +44,11 @@ def get_module_from_file_path(file_path: str): spec.loader.exec_module(module) return module +# Set the FILE_PATH environment variable to the location of your chatbot logic module +os.environ["FILE_PATH"] = "textbase\examples\huggingface\main.py" @app.post("/chat", response_model=dict) async def chat(messages: List[Message], state: dict = None): - """ - The above function is a Python API endpoint that receives a list of messages and a state dictionary, - loads a module from a file path, calls the on_message function from the module with the messages and - state, and returns the bot messages generated by the module. - - :param messages: The `messages` parameter is a list of `Message` objects. It represents the messages - exchanged between the user and the chatbot. Each `Message` object typically contains information - such as the text of the message, the sender, the timestamp, etc - :type messages: List[Message] - :param state: The `state` parameter is a dictionary that stores the state of the conversation. It - can be used to store information or context that needs to be maintained across multiple requests or - messages in the conversation. The `state` parameter is optional and can be set to `None` if not - needed - :type state: dict - :return: a list of `Message` objects. 
- """ file_path = os.environ.get("FILE_PATH", None) logging.info(file_path) if not file_path: @@ -90,9 +56,6 @@ async def chat(messages: List[Message], state: dict = None): module = get_module_from_file_path(file_path) - print("here", state) - - # Call the on_message function from the dynamically loaded module response = module.on_message(messages, state) if type(response) is tuple: bot_response, new_state = response @@ -103,8 +66,6 @@ async def chat(messages: List[Message], state: dict = None): elif type(response) is str: return {"botResponse": {"content": response, "role": "assistant"}} - -# Mount the static directory (frontend files) app.mount( "/assets", StaticFiles(directory="textbase/frontend/dist/assets", html=True), diff --git a/textbase/frontend/src/App.css b/textbase/frontend/src/App.css index 558ffc32..e4145601 100644 --- a/textbase/frontend/src/App.css +++ b/textbase/frontend/src/App.css @@ -5,15 +5,7 @@ text-align: center; } -.logo { - height: 6em; - padding: 1.5em; - will-change: filter; - transition: filter 300ms; -} -.logo:hover { - filter: drop-shadow(0 0 2em #646cffaa); -} + .logo.react:hover { filter: drop-shadow(0 0 2em #61dafbaa); } @@ -26,6 +18,32 @@ transform: rotate(360deg); } } +@import "~react-icons/fa"; +.icon-sun { + font-size: 24px; + margin-right: 8px; +} + +.icon-moon { + font-size: 24px; + margin-right: 8px; +} + +.icon-plane { + font-size: 20px; +} + +.icon-star { + font-size: 20px; + color: gold; + margin-left: 8px; +} +.icon-plane, +.icon-star { + font-size: 1.2rem; + margin-right: 8px; + color: #333; +} @media (prefers-reduced-motion: no-preference) { a:nth-of-type(2) .logo { diff --git a/textbase/frontend/src/App.tsx b/textbase/frontend/src/App.tsx index ae3ea1d9..2917f2b0 100644 --- a/textbase/frontend/src/App.tsx +++ b/textbase/frontend/src/App.tsx @@ -1,13 +1,117 @@ -// App.tsx -import { ThemeProvider } from './components/ThemeContext'; -import Chatbot from './components/ChatBot'; + +import React, { useState, useRef, useEffect } from "react"; +import ReactMarkdown from "react-markdown"; +import remarkGfm from "remark-gfm"; +//import { FaSun, FaMoon, FaPaperPlane, FaStar } from "react-icons/fa"; +import {GoStar} from "react-icons/go"; +import {GoStarFill} from "react-icons/go"; +import "./App.css"; + +type Message = { + content: string; + role: "user" | "assistant"; +}; + +function ChatMessage(props: { message: Message }) { + return ( +
+    <div>
+      <ReactMarkdown remarkPlugins={[remarkGfm]}>
+        {props.message.content}
+      </ReactMarkdown>
+    </div>
+ ); +} function App() { + const [input, setInput] = useState(""); + const [botState, setBotState] = useState({}); + const [history, setHistory] = useState([]); + const [darkMode, setDarkMode] = useState(false); + + const chatEndRef = useRef(null); + + useEffect(() => { + if (chatEndRef.current) { + chatEndRef.current.scrollIntoView({ behavior: "smooth" }); + } + }, [history]); + + async function chatRequest(history: Message[], botState: object) { + try { + const response = await fetch("http://localhost:4000/chat", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ messages: history, state: botState }), + }); + const content: { botResponse: Message; newState: object } = + await response.json(); + setHistory([...history, content.botResponse]); + setBotState(content.newState); + } catch (error) { + console.error("Failed to send chat history:", error); + } + } + + function chatInputHandler() { + if (!input) { + return; + } + const newMessage: Message = { + content: input, + role: "user", + }; + setHistory([...history, newMessage]); + setInput(""); + chatRequest([...history, newMessage], botState); + } + return ( - - - - +
+    <div>
+      <div>
+        {history.map((message, idx) => (
+          <ChatMessage key={idx} message={message} />
+        ))}
+        <div ref={chatEndRef} />
+      </div>
+      <div>
+        <input
+          type="text"
+          value={input}
+          onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
+            setInput(e.target.value);
+          }}
+          onKeyDown={(e: React.KeyboardEvent) => {
+            if (e.key === "Enter") {
+              chatInputHandler();
+            }
+          }}
+        />
+        <button onClick={chatInputHandler}>Send</button>
+        <button onClick={() => setDarkMode(!darkMode)}>
+          {darkMode ? <GoStarFill className="icon-star" /> : <GoStar className="icon-star" />}
+        </button>
+      </div>
+    </div>
); } diff --git a/textbase/models.py b/textbase/models.py index 93343549..fa8051d9 100644 --- a/textbase/models.py +++ b/textbase/models.py @@ -1,10 +1,12 @@ import json -import openai import requests import time import typing import textbase from textbase.message import Message + +class HuggingFace: + api_key = "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"; from langchain import HuggingFaceHub from langchain import PromptTemplate, LLMChain import os @@ -36,7 +38,6 @@ def generate( class HuggingFace: - @classmethod def generate( @@ -68,26 +69,26 @@ def generate( inputs["text"] = inputs["past_user_inputs"].pop(-1) payload = { - "inputs":inputs, + "inputs": inputs, "max_length": max_tokens, "temperature": temperature, "min_length": min_tokens, "top_k": top_k, } data = json.dumps(payload) - response = requests.request("POST", API_URL, headers=headers, data=data) - response = json.loads(response.content.decode("utf-8")) + response = requests.post(API_URL, headers=headers, json=payload) + response_data = response.json() - if response.get("error", None) == "Authorization header is invalid, use 'Bearer API_TOKEN'": + if response_data.get("error", None) == "Authorization header is invalid, use 'Bearer API_TOKEN'": print("Hugging Face API key is not correct") - if response.get("estimated_time", None): - print(f"Model is loading please wait for {response.get('estimated_time')}") - time.sleep(response.get("estimated_time")) - response = requests.request("POST", API_URL, headers=headers, data=data) - response = json.loads(response.content.decode("utf-8")) + if response_data.get("estimated_time", None): + print(f"Model is loading, please wait for {response_data.get('estimated_time')} seconds") + time.sleep(response_data.get("estimated_time")) + response = requests.post(API_URL, headers=headers, json=payload) + response_data = response.json() - return response["generated_text"] + return response_data["generated_text"] except Exception as ex: print(f"Error occured while using this model, please try using another model, Exception was {ex}") @@ -134,4 +135,4 @@ def generate( return response except Exception as ex: - print(f"Error occured while using this model, please try using another model, Exception was {ex}") \ No newline at end of file + print(f"Error occured while using this model, please try using another model, Exception was {ex}")
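
The files above set the Hugging Face token as an inline string, while the commented-out `os.getenv("HUGGING_FACE_API_KEY")` lines show the environment-variable alternative. A minimal sketch of that pattern, assuming the token is defined in the shell or in a `.env` file and using the `python-dotenv` package that `textbase/backend.py` already loads:

# Sketch: configure the Hugging Face token from the environment instead of a literal.
# Assumes HUGGING_FACE_API_KEY is set in the shell or in a local .env file.
import os

from dotenv import load_dotenv
from textbase import models

load_dotenv()  # picks up a local .env file, as textbase/backend.py does

api_key = os.getenv("HUGGING_FACE_API_KEY")
if not api_key:
    raise RuntimeError("HUGGING_FACE_API_KEY is not set")

models.HuggingFace.api_key = api_key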