Skip to content

Commit

Permalink
Merge pull request #75 from JamesHWade/fix/better-index-loading
Browse files Browse the repository at this point in the history
fix: better index saving and loading
  • Loading branch information
JamesHWade authored Feb 1, 2024
2 parents 5d6e15c + b332e46 commit f46f6d0
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 190 deletions.
181 changes: 0 additions & 181 deletions R/embedding.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,144 +222,18 @@ create_index <- function(domain,
}
}


#' Index All Scraped Data
#'
#' Iterates over every file that `list_index("text", full_path = TRUE)`
#' reports and calls `create_index()` once per file, using the file name
#' (without directory and extension) as the domain.
#'
#' @param overwrite A logical value determining whether to overwrite existing
#'   indexes. Passed on to `create_index()`.
#' @param local_embeddings A logical indicating whether to use local embeddings
#'   for indexing. Passed on to `create_index()`.
#' @param dont_ask A logical value that, if TRUE, disables interactive
#'   confirmation prompts during the indexing process. Passed on to
#'   `create_index()`.
#'
#' @details For each file an informational message naming the domain is
#'   emitted before the index is created or updated. When no text files are
#'   found, nothing is indexed.
#'
#' @return The character vector of text file paths, invisibly (the value of
#'   `purrr::walk()`). The function is called for its side effects.
#'
#' @examples
#' # Index all scraped data without overwriting existing indexes, using local
#' # embeddings, and without interactive prompts.
#'
#' \dontrun{
#' gpttools_index_all_scraped_data(
#'   overwrite = FALSE,
#'   local_embeddings = TRUE,
#'   dont_ask = TRUE
#' )
#' }
#'
#' @export
gpttools_index_all_scraped_data <- function(overwrite = FALSE,
                                            local_embeddings = TRUE,
                                            dont_ask = TRUE) {
  text_files <- list_index("text", full_path = TRUE)

  purrr::walk(text_files, \(file_path) {
    domain <- tools::file_path_sans_ext(basename(file_path))
    # Let cli interpolate `domain` itself. Pre-formatting with glue() and then
    # handing the result to cli interpolates twice, which breaks on domain
    # names that contain braces.
    cli_alert_info("Creating/updating index for domain {domain}...")
    create_index(
      domain = domain,
      overwrite = overwrite,
      dont_ask = dont_ask,
      local_embeddings = local_embeddings
    )
  })
}


# Rank the rows of an index by cosine similarity to a query embedding.
#
# index:           data frame with an `embedding` column; each entry is passed
#                  through unlist() before being compared.
#                  NOTE(review): presumably a list-column of numeric vectors
#                  with the same length as `query_embedding` -- confirm.
# query_embedding: numeric vector to compare every row against.
# k:               maximum number of rows to return; capped at nrow(index).
#
# Returns `index` with an added `similarity` column, sorted by descending
# similarity and truncated to the first `k` rows.
get_top_matches <- function(index, query_embedding, k = 5) {
  k <- min(k, nrow(index))
  index |>
    dplyr::mutate(
      similarity = purrr::map_dbl(
        embedding,
        \(x) lsa::cosine(query_embedding, unlist(x))
      )
    ) |>
    dplyr::arrange(dplyr::desc(similarity)) |>
    head(k)
}

#' Load Index Data for a Domain
#'
#' Loads the index data for a given domain from parquet files stored in the
#' gpttools user data directory. When the index directory does not exist, the
#' sample index shipped with gpttools is returned instead.
#'
#' @param domain A character string indicating the name of the domain. The
#'   special value `"All"` opens every index file in the directory as a single
#'   dataset and collects it into a tibble.
#' @param local_embeddings A logical indicating whether to load the local
#'   embeddings (stored under `index/local`) or the OpenAI embeddings (stored
#'   under `index`). Defaults to FALSE.
#'
#' @return A data frame containing the index data for the specified domain.
#'   The sample index fallback is returned invisibly.
#'
#' @export
#'
#' @examples
#' \dontrun{
#' load_index("example_domain")
#' }
load_index <- function(domain, local_embeddings = FALSE) {
  index_root <- file.path(
    tools::R_user_dir("gpttools", which = "data"),
    "index"
  )
  data_dir <- if (local_embeddings) {
    file.path(index_root, "local")
  } else {
    index_root
  }

  if (!dir.exists(data_dir)) {
    cli_inform("No index found. Using sample index for gpttools.")

    sample_file <- if (local_embeddings) {
      "sample-index/local/jameshwade-github-io-gpttools.parquet"
    } else {
      "sample-index/jameshwade-github-io-gpttools.parquet"
    }
    sample_index <- system.file(sample_file, package = "gpttools")
    index <- arrow::read_parquet(sample_index)
    # Return here: without an explicit return() the function would fall
    # through and try to read from the directory that does not exist.
    return(invisible(index))
  }

  if (domain == "All") {
    arrow::open_dataset(
      data_dir,
      # Skip the nested "local" entries when reading the top-level index.
      factory_options = list(selector_ignore_prefixes = "local")
    ) |>
      tibble::as_tibble()
  } else {
    arrow::read_parquet(file.path(data_dir, paste0(domain, ".parquet")))
  }
}

# Open a sub-directory of the gpttools user data directory as an arrow
# dataset.
#
# dir_name: directory name relative to the gpttools user data directory.
#
# Returns the value of arrow::open_dataset() for that directory.
load_index_dir <- function(dir_name) {
  user_data_root <- tools::R_user_dir("gpttools", which = "data")
  dataset_path <- glue("{user_data_root}/{dir_name}")
  arrow::open_dataset(dataset_path)
}

# Read one parquet file from the gpttools user data directory.
#
# dir_name:  sub-directory of the gpttools user data directory.
# file_name: name of the parquet file inside `dir_name`.
#
# Returns the value of arrow::read_parquet() for that file.
load_scraped_data <- function(dir_name, file_name) {
  data_root <- tools::R_user_dir("gpttools", which = "data")
  parquet_path <- file.path(data_root, dir_name, file_name)
  arrow::read_parquet(parquet_path)
}


chunk_with_overlap <- function(x, chunk_size, overlap_size, doc_id, ...) {
stopifnot(is.character(x), length(x) == 1)
words <- tokenizers::tokenize_words(x, simplify = TRUE, ...)
Expand Down Expand Up @@ -390,58 +264,3 @@ chunk_with_overlap <- function(x, chunk_size, overlap_size, doc_id, ...) {
chunks <- purrr::compact(chunks)
purrr::map(chunks, \(x) stringr::str_c(x, collapse = " "))
}

#' List Index Files
#'
#' Lists the files found in a sub-directory of the gpttools user data
#' directory and tells the user where that directory is located.
#'
#' @param dir Name of the directory, defaults to "index".
#' @param full_path If TRUE, returns the full path to each file instead of the
#'   bare file name.
#'
#' @return A character vector containing the names (or full paths) of the
#'   files found in the specified directory.
#'
#' @examples
#' \dontrun{
#' list_index()
#' }
#' @export
list_index <- function(dir = "index", full_path = FALSE) {
  data_root <- tools::R_user_dir("gpttools", "data")
  loc <- file.path(data_root, dir)
  cli_inform("Access your index files here: {.file {loc}}")

  if (full_path) {
    return(list.files(loc, full.names = TRUE))
  }
  list.files(loc)
}


#' Delete an Index File
#'
#' Interactively deletes an index file from the "index" folder of the gpttools
#' user data directory. Presents the user with a menu of the available index
#' files and prompts for confirmation before deletion.
#'
#' @return `NULL`, invisibly. The function is called for its side effects.
#'
#' @export
delete_index <- function() {
  files <- list_index()
  if (length(files) == 0) {
    cli_alert_warning("No index files found.")
    return(invisible())
  }

  cli_alert("Select the index file you want to delete.")
  to_delete <- utils::menu(files)
  # utils::menu() returns 0 when the user exits without choosing. Indexing
  # with 0 would yield an empty selection and a bogus "deleted" message.
  if (to_delete == 0) {
    cli_alert_warning("Index not deleted.")
    return(invisible())
  }

  confirm_delete <-
    ui_yeah("Are you sure you want to delete {files[to_delete]}?")
  if (!confirm_delete) {
    cli_alert_warning("Index not deleted.")
    return(invisible())
  }

  target <- file.path(
    tools::R_user_dir("gpttools", "data"),
    "index",
    files[to_delete]
  )
  # file.remove() returns FALSE (with a warning) when removal fails, so only
  # report success when the file is actually gone.
  if (file.remove(target)) {
    cli_alert_success("Index deleted.")
  } else {
    cli_alert_warning("Index could not be deleted.")
  }
  invisible()
}
17 changes: 11 additions & 6 deletions R/history.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ check_context <- function(context) {
#' @param index Index to look for context.
#' @param add_context Whether to add context to the query or not. Default is
#' TRUE.
#' @param check_context Whether to check if context is needed. Default is FALSE.
#' @param chat_history Chat history dataframe for reference.
#' @param history_name Name of the file where chat history is stored.
#' @param session_history Session history data for reference.
Expand Down Expand Up @@ -189,6 +190,7 @@ chat_with_context <- function(query,
model = "gpt-4-turbo-preview",
index = NULL,
add_context = TRUE,
check_context = FALSE,
chat_history = NULL,
history_name = "chat_history",
session_history = NULL,
Expand All @@ -202,11 +204,15 @@ chat_with_context <- function(query,
embedding_model = NULL) {
arg_match(task, c("Context Only", "Permissive Chat"))

need_context <- is_context_needed(
user_prompt = query,
service = service,
model = model
)
if (rlang::is_true(check_context)) {
need_context <- is_context_needed(
user_prompt = query,
service = service,
model = model
)
} else {
need_context <- TRUE
}

if (rlang::is_true(add_context) || rlang::is_true(add_history)) {
cli_alert_info("Creating embedding from query.")
Expand All @@ -227,7 +233,6 @@ chat_with_context <- function(query,

context <-
full_context |>
dplyr::glimpse() |>
dplyr::select(source, link, chunks) |>
purrr::pmap(\(source, link, chunks) {
glue::glue("Source: {source}
Expand Down
Loading

0 comments on commit f46f6d0

Please sign in to comment.