Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore functionality of Imaginate #1760

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions node-graph/gcore/src/raster/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ impl<P: Pixel> Image<P> {
}
}

pub fn from_raw_buffer(width: u32, height: u32, data: Vec<u8>) -> Self {
Self {
width,
height,
data: bytemuck::cast_vec(data),
base64_string: None,
}
}

pub fn as_slice(&self) -> ImageSlice<P> {
ImageSlice {
width: self.width,
Expand All @@ -147,13 +156,16 @@ impl Image<Color> {
base64_string: None,
}
}
}

impl Image<SRGBA8> {
pub fn to_png(&self) -> Vec<u8> {
use ::image::ImageEncoder;
let (data, width, height) = self.to_flat_u8();
let mut png = Vec::new();
let encoder = ::image::codecs::png::PngEncoder::new(&mut png);
encoder.write_image(&data, width, height, ::image::ExtendedColorType::Rgba8).expect("failed to encode image as png");
encoder
.write_image(bytemuck::cast_slice(self.data.as_slice()), self.width, self.height, ::image::ColorType::Rgba8)
.expect("failed to encode image as png");
png
}
}
Expand Down Expand Up @@ -214,6 +226,15 @@ where

(result, *width, *height)
}

pub fn to_png(&self) -> Vec<u8> {
use ::image::ImageEncoder;
let (data, width, height) = self.to_flat_u8();
let mut png = Vec::new();
let encoder = ::image::codecs::png::PngEncoder::new(&mut png);
encoder.write_image(&data, width, height, ::image::ColorType::Rgba8).expect("failed to encode image as png");
png
}
}

impl<P: Pixel> IntoIterator for Image<P> {
Expand Down
1 change: 1 addition & 0 deletions node-graph/gstd/src/http.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::Node;
use graphene_core::raster::{ImageFrame, SRGBA8};

pub struct GetNode;

Expand Down
138 changes: 138 additions & 0 deletions node-graph/gstd/src/imaginate_v2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
use graphene_core::raster::{Image, ImageFrame, Pixel, SRGBA8};

use crate::Node;

async fn image_to_image(image_frame: ImageFrame<SRGBA8>, prompt: String) -> reqwest::Result<ImageFrame<SRGBA8>> {
let png_bytes = image_frame.image.to_png();
//let base64 = base64::encode(png_bytes);
// post to cloudflare image to image endpoint using reqwest
let payload = PayloadBuilder::new()
.guidance(7.5)
.image(png_bytes.to_vec())
//.mask(png_bytes.to_vec())
.num_steps(20)
.prompt(prompt)
.strength(1);

let client = Client::new();
let account_id = "xxx";
let api_key = "123";
let request = client
//.post(format!("https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/bytedance/stable-diffusion-xl-base-1.0"))
//.post(format!("https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/stabilityai/stable-diffusion-xl-base-1.0"))
/*.post(format!(
"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/runwayml/stable-diffusion-v1-5-inpainting"
))*/
.post(format!("https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/@cf/runwayml/stable-diffusion-v1-5-img2img"))
.json(&payload)
.header("Authorization", format!("Bearer {api_key}"));
//println!("{}", serde_json::to_string(&payload).unwrap());
let response = dbg!(request).send().await?;

#[derive(Debug, serde::Deserialize)]
struct Response {
result: String,
success: bool,
};

match response.error_for_status_ref() {
Ok(_) => (),
Err(_) => panic!("{}", response.text().await?),
}
//let text: Response = response.json().await?;
/*let text = response.text().await?;
let text = Response {
result: serde_json::Value::String(text),
success: false,
};
dbg!(&text);*/

let bytes = response.bytes().await?;
//let bytes = &[];

let image = image::load_from_memory_with_format(&bytes[..], image::ImageFormat::Png).unwrap();
let width = image.width();
let height = image.height();
let image = image.to_rgba8();
let data = image.as_raw();
let color_data = bytemuck::cast_slice(data).to_owned();
let image = Image {
width,
height,
data: color_data,
base64_string: None,
};

let image_frame = ImageFrame { image, ..image_frame };
Ok(image_frame)
}
use reqwest::Client;
use serde::Serialize;

#[derive(Default, Serialize)]
struct PayloadBuilder {
#[serde(skip_serializing_if = "Option::is_none")]
guidance: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
image: Option<Vec<u8>>,
#[serde(skip_serializing_if = "Option::is_none")]
mask: Option<Vec<u8>>,
#[serde(skip_serializing_if = "Option::is_none")]
num_steps: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
prompt: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
strength: Option<u32>,
}

impl PayloadBuilder {
fn new() -> Self {
Self::default()
}

fn guidance(mut self, value: f64) -> Self {
self.guidance = Some(value);
self
}

fn image(mut self, value: Vec<u8>) -> Self {
self.image = Some(value);
self
}

fn mask(mut self, value: Vec<u8>) -> Self {
self.mask = Some(value);
self
}

fn num_steps(mut self, value: u32) -> Self {
self.num_steps = Some(value);
self
}

fn prompt(mut self, value: String) -> Self {
self.prompt = Some(value);
self
}

fn strength(mut self, value: u32) -> Self {
self.strength = Some(value);
self
}
}

#[cfg(test)]
mod test {
use super::*;
use graphene_core::{raster::Image, Color};
#[tokio::test]
async fn test_cloudflare() {
let test_image = ImageFrame {
image: Image::new(1024, 1024, SRGBA8::from(Color::RED)),
..Default::default()
};
let result = image_to_image(test_image, "make green".into()).await;
dbg!(result.unwrap());
panic!("show result");
}
}
3 changes: 3 additions & 0 deletions node-graph/gstd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ pub mod vector;

pub mod http;

#[cfg(feature = "serde")]
pub mod imaginate_v2;

pub mod any;

#[cfg(feature = "gpu")]
Expand Down
12 changes: 6 additions & 6 deletions shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@ let
# wasm-pack needs this
extensions = [ "rust-src" "rust-analyzer" "clippy"];
};
in
# Make a shell with the dependencies we need
pkgs.mkShell {
packages = with pkgs; [

packages = with pkgs; [
rustc-wasm
nodejs
cargo
Expand Down Expand Up @@ -72,9 +70,11 @@ in
# Use Mold as a linker
mold
];

in
# Make a shell with the dependencies we need
pkgs.mkShell {
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath packages;
# Hacky way to run Cargo through Mold
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [pkgs.openssl pkgs.vulkan-loader pkgs.libxkbcommon pkgs.llvmPackages.libcxxStdenv pkgs.gcc-unwrapped.lib pkgs.llvm pkgs.libraw];
shellHook = ''
alias cargo='mold --run cargo'
'';
Expand Down
Loading