Skip to content

Commit

Permalink
feat: improve BaseNode (#848)
Browse files Browse the repository at this point in the history
  • Loading branch information
himself65 authored May 16, 2024
1 parent 10c8348 commit 5124186
Show file tree
Hide file tree
Showing 19 changed files with 293 additions and 90 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ jobs:
- name: Run Type Check
run: pnpm run type-check
- name: Run Circular Dependency Check
run: pnpm run circular-check
working-directory: ./packages/core
run: pnpm dlx turbo run circular-check
- uses: actions/upload-artifact@v3
if: failure()
with:
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"eslint-plugin-react": "7.34.1",
"husky": "^9.0.11",
"lint-staged": "^15.2.2",
"madge": "^7.0.0",
"prettier": "^3.2.5",
"prettier-plugin-organize-imports": "^3.2.4",
"turbo": "^1.13.3",
Expand Down
7 changes: 7 additions & 0 deletions packages/core/.madgerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"detectiveOptions": {
"ts": {
"skipTypeImports": true
}
}
}
1 change: 0 additions & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
"@swc/core": "^1.5.5",
"concurrently": "^8.2.2",
"glob": "^10.3.12",
"madge": "^7.0.0",
"typescript": "^5.4.5"
},
"engines": {
Expand Down
148 changes: 92 additions & 56 deletions packages/core/src/Node.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { createSHA256, path, randomUUID } from "@llamaindex/env";
import _ from "lodash";
import { chunkSizeCheck, lazyInitHash } from "./internal/decorator/node.js";

export enum NodeRelationship {
SOURCE = "SOURCE",
Expand Down Expand Up @@ -37,6 +37,16 @@ export type RelatedNodeType<T extends Metadata = Metadata> =
| RelatedNodeInfo<T>
| RelatedNodeInfo<T>[];

/**
 * Initialisation options accepted by the BaseNode constructor.
 *
 * Every field is optional; the BaseNode constructor fills in a default for
 * each omitted field (see the per-field notes below).
 */
export type BaseNodeParams<T extends Metadata = Metadata> = {
// Node identifier; the constructor falls back to randomUUID() when omitted.
id_?: string;
// Metadata attached to the node; defaults to an empty object.
metadata?: T;
// Defaults to an empty array.
excludedEmbedMetadataKeys?: string[];
// Defaults to an empty array.
excludedLlmMetadataKeys?: string[];
// Related nodes keyed by NodeRelationship; defaults to an empty object.
relationships?: Partial<Record<NodeRelationship, RelatedNodeType<T>>>;
// NOTE(review): the BaseNode constructor destructures this value but does not
// assign it anywhere in the visible code — confirm how a caller-supplied hash
// is meant to reach the `hash` accessor.
hash?: string;
// Stored as-is on the node; stays undefined when omitted.
embedding?: number[];
};

/**
* Generic abstract class for retrievable nodes
*/
Expand All @@ -47,21 +57,37 @@ export abstract class BaseNode<T extends Metadata = Metadata> {
*
* Set to a UUID by default.
*/
id_: string = randomUUID();
id_: string;
embedding?: number[];

// Metadata fields
metadata: T = {} as T;
excludedEmbedMetadataKeys: string[] = [];
excludedLlmMetadataKeys: string[] = [];
relationships: Partial<Record<NodeRelationship, RelatedNodeType<T>>> = {};
hash: string = "";

constructor(init?: Partial<BaseNode<T>>) {
Object.assign(this, init);
}

abstract getType(): ObjectType;
metadata: T;
excludedEmbedMetadataKeys: string[];
excludedLlmMetadataKeys: string[];
relationships: Partial<Record<NodeRelationship, RelatedNodeType<T>>>;

@lazyInitHash
accessor hash: string = "";

/**
 * Initialises the fields shared by every node type from an optional
 * BaseNodeParams object, applying a default for each omitted field.
 *
 * Protected, so it is only reachable from subclasses via super().
 *
 * NOTE(review): `hash` is destructured below but never assigned to the
 * instance in this constructor. Presumably the @lazyInitHash accessor is
 * expected to compute it on demand — confirm that a caller-supplied hash
 * is not silently discarded.
 */
protected constructor(init?: BaseNodeParams<T>) {
const {
id_,
metadata,
excludedEmbedMetadataKeys,
excludedLlmMetadataKeys,
relationships,
hash,
embedding,
} = init || {};
// `??` keeps explicit falsy-but-valid values and only substitutes the
// default for null/undefined.
this.id_ = id_ ?? randomUUID();
this.metadata = metadata ?? ({} as T);
this.excludedEmbedMetadataKeys = excludedEmbedMetadataKeys ?? [];
this.excludedLlmMetadataKeys = excludedLlmMetadataKeys ?? [];
this.relationships = relationships ?? {};
this.embedding = embedding;
}

abstract get type(): ObjectType;

abstract getContent(metadataMode: MetadataMode): string;
abstract getMetadataStr(metadataMode: MetadataMode): string;
Expand Down Expand Up @@ -146,7 +172,12 @@ export abstract class BaseNode<T extends Metadata = Metadata> {
* @see toMutableJSON - use to return a mutable JSON instead
*/
toJSON(): Record<string, any> {
return { ...this, type: this.getType() };
return {
...this,
type: this.type,
// `hash` is declared with `accessor`, so it is not an own enumerable
// property and object spread (`...this`) does not copy it; add it explicitly.
hash: this.hash,
};
}

clone(): BaseNode {
Expand All @@ -159,32 +190,43 @@ export abstract class BaseNode<T extends Metadata = Metadata> {
* @return {Record<string, any>} - The JSON representation of the object.
*/
toMutableJSON(): Record<string, any> {
return _.cloneDeep(this.toJSON());
return structuredClone(this.toJSON());
}
}

/**
 * Initialisation options accepted by the TextNode constructor: everything in
 * BaseNodeParams plus the text-specific fields below. All fields are optional.
 */
export type TextNodeParams<T extends Metadata = Metadata> =
BaseNodeParams<T> & {
// Text content of the node; the constructor defaults it to "".
text?: string;
// The constructor defaults it to "".
textTemplate?: string;
// Optional character indices; both values are fed into generateHash().
startCharIdx?: number;
endCharIdx?: number;
// The constructor defaults it to "\n".
metadataSeparator?: string;
};

/**
* TextNode is the default node type for text. Most common node type in LlamaIndex.TS
*/
export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> {
text: string = "";
textTemplate: string = "";
text: string;
textTemplate: string;

startCharIdx?: number;
endCharIdx?: number;
// textTemplate: NOTE write your own formatter if needed
// metadataTemplate: NOTE write your own formatter if needed
metadataSeparator: string = "\n";
metadataSeparator: string;

constructor(init?: Partial<TextNode<T>>) {
constructor(init: TextNodeParams<T> = {}) {
super(init);
Object.assign(this, init);

if (new.target === TextNode) {
// Don't generate the hash repeatedly so only do it if this is
// constructing the derived class
this.hash = init?.hash ?? this.generateHash();
const { text, textTemplate, startCharIdx, endCharIdx, metadataSeparator } =
init;
this.text = text ?? "";
this.textTemplate = textTemplate ?? "";
// Test for null/undefined instead of truthiness: 0 is a valid character
// index, and a truthiness check would drop it, leaving startCharIdx
// undefined and changing the `startCharIdx=` input to generateHash().
if (startCharIdx != null) {
this.startCharIdx = startCharIdx;
}
this.endCharIdx = endCharIdx;
this.metadataSeparator = metadataSeparator ?? "\n";
}

/**
Expand All @@ -194,18 +236,19 @@ export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> {
*/
generateHash() {
const hashFunction = createSHA256();
hashFunction.update(`type=${this.getType()}`);
hashFunction.update(`type=${this.type}`);
hashFunction.update(
`startCharIdx=${this.startCharIdx} endCharIdx=${this.endCharIdx}`,
);
hashFunction.update(this.getContent(MetadataMode.ALL));
return hashFunction.digest();
}

getType(): ObjectType {
get type() {
return ObjectType.TEXT;
}

@chunkSizeCheck
getContent(metadataMode: MetadataMode = MetadataMode.NONE): string {
const metadataStr = this.getMetadataStr(metadataMode).trim();
return `${metadataStr}\n\n${this.text}`.trim();
Expand Down Expand Up @@ -246,19 +289,21 @@ export class TextNode<T extends Metadata = Metadata> extends BaseNode<T> {
}
}

/**
 * Initialisation options accepted by the IndexNode constructor: everything in
 * TextNodeParams plus the id of the index this node refers to.
 *
 * NOTE(review): `indexId` is required here, yet the IndexNode constructor
 * takes `init` as optional and falls back to "" when `indexId` is missing —
 * confirm which of the two contracts is intended.
 */
export type IndexNodeParams<T extends Metadata = Metadata> =
TextNodeParams<T> & {
indexId: string;
};

export class IndexNode<T extends Metadata = Metadata> extends TextNode<T> {
indexId: string = "";
indexId: string;

constructor(init?: Partial<IndexNode<T>>) {
constructor(init?: IndexNodeParams<T>) {
super(init);
Object.assign(this, init);

if (new.target === IndexNode) {
this.hash = init?.hash ?? this.generateHash();
}
const { indexId } = init || {};
this.indexId = indexId ?? "";
}

getType(): ObjectType {
get type() {
return ObjectType.INDEX;
}
}
Expand All @@ -267,16 +312,11 @@ export class IndexNode<T extends Metadata = Metadata> extends TextNode<T> {
* A document is just a special text node with a docId.
*/
export class Document<T extends Metadata = Metadata> extends TextNode<T> {
constructor(init?: Partial<Document<T>>) {
constructor(init?: TextNodeParams<T>) {
super(init);
Object.assign(this, init);

if (new.target === Document) {
this.hash = init?.hash ?? this.generateHash();
}
}

getType() {
get type() {
return ObjectType.DOCUMENT;
}
}
Expand All @@ -303,21 +343,21 @@ export function jsonToNode(json: any, type?: ObjectType) {

export type ImageType = string | Blob | URL;

export type ImageNodeConstructorProps<T extends Metadata> = Pick<
ImageNode<T>,
"image" | "id_"
> &
Partial<ImageNode<T>>;
/**
 * Initialisation options accepted by the ImageNode and ImageDocument
 * constructors: everything in TextNodeParams plus a required image.
 */
export type ImageNodeParams<T extends Metadata = Metadata> =
TextNodeParams<T> & {
// Required; an ImageType (string | Blob | URL) that the ImageNode
// constructor stores unchanged on `image`.
image: ImageType;
};

export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
image: ImageType; // image as blob

constructor(init: ImageNodeConstructorProps<T>) {
constructor(init: ImageNodeParams<T>) {
super(init);
this.image = init.image;
const { image } = init;
this.image = image;
}

getType(): ObjectType {
get type() {
return ObjectType.IMAGE;
}

Expand Down Expand Up @@ -360,15 +400,11 @@ export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
}

export class ImageDocument<T extends Metadata = Metadata> extends ImageNode<T> {
constructor(init: ImageNodeConstructorProps<T>) {
constructor(init: ImageNodeParams<T>) {
super(init);

if (new.target === ImageDocument) {
this.hash = init?.hash ?? this.generateHash();
}
}

getType() {
get type() {
return ObjectType.IMAGE_DOCUMENT;
}
}
Expand Down
17 changes: 10 additions & 7 deletions packages/core/src/Settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ import {
setCallbackManager,
withCallbackManager,
} from "./internal/settings/CallbackManager.js";
import {
getChunkSize,
setChunkSize,
withChunkSize,
} from "./internal/settings/chunk-size.js";
import type { LLM } from "./llm/types.js";
import type { NodeParser } from "./nodeParsers/types.js";

Expand Down Expand Up @@ -41,14 +46,12 @@ class GlobalSettings implements Config {
#promptHelper: PromptHelper | null = null;
#embedModel: BaseEmbedding | null = null;
#nodeParser: NodeParser | null = null;
#chunkSize?: number;
#chunkOverlap?: number;

#llmAsyncLocalStorage = new AsyncLocalStorage<LLM>();
#promptHelperAsyncLocalStorage = new AsyncLocalStorage<PromptHelper>();
#embedModelAsyncLocalStorage = new AsyncLocalStorage<BaseEmbedding>();
#nodeParserAsyncLocalStorage = new AsyncLocalStorage<NodeParser>();
#chunkSizeAsyncLocalStorage = new AsyncLocalStorage<number>();
#chunkOverlapAsyncLocalStorage = new AsyncLocalStorage<number>();
#promptAsyncLocalStorage = new AsyncLocalStorage<PromptConfig>();

Expand Down Expand Up @@ -115,8 +118,8 @@ class GlobalSettings implements Config {
get nodeParser(): NodeParser {
if (this.#nodeParser === null) {
this.#nodeParser = new SimpleNodeParser({
chunkSize: this.#chunkSize,
chunkOverlap: this.#chunkOverlap,
chunkSize: this.chunkSize,
chunkOverlap: this.chunkOverlap,
});
}

Expand Down Expand Up @@ -147,15 +150,15 @@ class GlobalSettings implements Config {
}

set chunkSize(chunkSize: number | undefined) {
this.#chunkSize = chunkSize;
setChunkSize(chunkSize);
}

get chunkSize(): number | undefined {
return this.#chunkSizeAsyncLocalStorage.getStore() ?? this.#chunkSize;
return getChunkSize();
}

withChunkSize<Result>(chunkSize: number, fn: () => Result): Result {
return this.#chunkSizeAsyncLocalStorage.run(chunkSize, fn);
return withChunkSize(chunkSize, fn);
}

get chunkOverlap(): number | undefined {
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/indices/vectorStore/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
// NOTE: if the vector store keeps text,
// we only need to add image and index nodes
for (let i = 0; i < nodes.length; ++i) {
const type = nodes[i].getType();
const { type } = nodes[i];
if (
!vectorStore.storesText ||
type === ObjectType.INDEX ||
Expand Down
Loading

0 comments on commit 5124186

Please sign in to comment.