Skip to content

Commit

Permalink
fix: potential RPC connection issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Supremesource committed May 15, 2024
1 parent 9a33313 commit 0e6a24a
Show file tree
Hide file tree
Showing 13 changed files with 67 additions and 69 deletions.
15 changes: 6 additions & 9 deletions src/synthia/cli.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import typer
from typing import Annotated, Optional
from rich.console import Console

import typer
from communex._common import get_node_url
from communex.client import CommuneClient
from communex.compat.key import classic_load_key
from rich.console import Console

from synthia.validator.text_validator import (
TextValidator,
ValidatorSettings,
get_synthia_netuid,
ClaudeProviders,
)

from synthia.validator.text_validator import (ClaudeProviders, TextValidator,
ValidatorSettings,
get_synthia_netuid)

app = typer.Typer()

Expand Down
3 changes: 2 additions & 1 deletion src/synthia/miner/BaseLLM.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from communex.module.module import Module, endpoint # type: ignore
from abc import ABC, abstractmethod

from communex.module.module import Module, endpoint # type: ignore
from fastapi import HTTPException


Expand Down
1 change: 1 addition & 0 deletions src/synthia/miner/_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic_settings import BaseSettings


class AnthropicSettings(BaseSettings):
api_key: str
model: str = "claude-3-opus-20240229"
Expand Down
15 changes: 6 additions & 9 deletions src/synthia/miner/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import json
from typing import Any

import requests
from anthropic import Anthropic
from communex.module.module import Module, endpoint # type: ignore
from anthropic._types import NotGiven
from communex.key import generate_keypair # type: ignore
from communex.module.module import Module, endpoint # type: ignore
from keylimiter import TokenBucketLimiter

import requests
import json



from ._config import AnthropicSettings, OpenrouterSettings # Import the AnthropicSettings class from config
from ..utils import log # Import the log function from utils
from ._config import ( # Import the AnthropicSettings class from config
AnthropicSettings, OpenrouterSettings)
from .BaseLLM import BaseLLM


Expand Down Expand Up @@ -136,9 +134,8 @@ def prompt(self, user_prompt: str, system_prompt: str | None = None):


if __name__ == "__main__":
from communex.module.server import ModuleServer # type: ignore

import uvicorn
from communex.module.server import ModuleServer # type: ignore
key = generate_keypair()
log(f"Running module with key {key.ss58_address}")
claude = OpenrouterModule()
Expand Down
3 changes: 2 additions & 1 deletion src/synthia/tests/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio

from communex.compat.key import classic_load_key # type: ignore
from communex.module.client import ModuleClient # type: ignore
import asyncio

if __name__ == "__main__":
from communex.compat.key import classic_load_key # type: ignore
Expand Down
1 change: 1 addition & 0 deletions src/synthia/tests/distribution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math


def sigmoid(x: float):
return 1 / (1 + math.exp(-x))

Expand Down
1 change: 0 additions & 1 deletion src/synthia/tests/vote.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from communex.client import CommuneClient # type: ignore
from communex.compat.key import classic_load_key # type: ignore


client = CommuneClient("wss://testnet-commune-api-node-0.communeai.net")


Expand Down
8 changes: 4 additions & 4 deletions src/synthia/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import sys
import datetime
import random
from time import sleep
import sys
import time
from typing import Callable, TypeVar, ParamSpec, Literal, Any
import datetime
from functools import wraps
from time import sleep
from typing import Any, Callable, Literal, ParamSpec, TypeVar

T = TypeVar("T")
T1 = TypeVar("T1")
Expand Down
2 changes: 1 addition & 1 deletion src/synthia/validator/_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from communex.compat.types import Ss58Address #  type: ignore
from communex.compat.types import Ss58Address # type: ignore
from pydantic_settings import BaseSettings


Expand Down
5 changes: 2 additions & 3 deletions src/synthia/validator/generate_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
from typing import cast, Any
from typing import cast

from .meta_prompt import explanation_prompt
from ..miner.BaseLLM import BaseLLM
from .meta_prompt import explanation_prompt


class InputGenerator:
Expand Down
1 change: 1 addition & 0 deletions src/synthia/validator/sigmoid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math


def sigmoid(x: float):
return 1 / (1 + math.exp(-x))

Expand Down
9 changes: 4 additions & 5 deletions src/synthia/validator/similarity.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Protocol
from dataclasses import dataclass
from typing import Any
from typing import Any, Protocol

from pydantic_settings import BaseSettings
import openai
import numpy
from transformers import pipeline, Pipeline # type: ignore
import openai
from pydantic_settings import BaseSettings
from transformers import Pipeline, pipeline # type: ignore

# from ..utils import log

Expand Down
72 changes: 37 additions & 35 deletions src/synthia/validator/text_validator.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,36 @@
import asyncio
import random
import re
import time
import random
from enum import Enum
from dataclasses import dataclass

from enum import Enum

import numpy as np
import requests
from communex._common import get_node_url # type: ignore
from communex.client import CommuneClient # type: ignore
from communex.compat.key import check_ss58_address # type: ignore
from communex.module.client import ModuleClient # type: ignore
from communex.module.module import Module # type: ignore
from communex.compat.key import check_ss58_address # type: ignore
from communex.types import Ss58Address # type: ignore
from fuzzywuzzy import fuzz # type: ignore
from substrateinterface import Keypair # type: ignore

from ..miner._config import AnthropicSettings, OpenrouterSettings
from ..miner.anthropic import AnthropicModule, OpenrouterModule
from ..utils import retry, log
from ..utils import log, retry
from ._config import ValidatorSettings
from .generate_data import InputGenerator
from .meta_prompt import get_miner_prompt, Criteria
from .similarity import Embedder, OpenAIEmbedder, OpenAISettings, euclidean_distance
from .meta_prompt import Criteria, get_miner_prompt
from .sigmoid import threshold_sigmoid_reward_distribution
from .similarity import (Embedder, OpenAIEmbedder, OpenAISettings,
euclidean_distance)

# TODO: make it match ipv6
IP_REGEX = re.compile(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d+")
NUM_QUESTIONS_PER_CYCLE=5
MINIMUM_DATASET_SCORE=0.7
NUM_QUESTIONS_PER_CYCLE = 5
MINIMUM_DATASET_SCORE = 0.7


def set_weights(
score_dict: dict[int, float], netuid: int, client: CommuneClient, key: Keypair
Expand Down Expand Up @@ -87,15 +89,16 @@ def cut_to_max_allowed_weights(
settings = ValidatorSettings() # type: ignore

max_allowed_weights = settings.max_allowed_weights
# sort the score by highest to lowest

# sort the score by highest to lowest
sorted_scores = sorted(score_dict.items(), key=lambda x: x[1], reverse=True)

# cut to max_allowed_weights
# cut to max_allowed_weights
cut_scores = sorted_scores[:max_allowed_weights]

return dict(cut_scores)


def extract_address(string: str):
"""
Extracts an address from a string.
Expand Down Expand Up @@ -129,10 +132,12 @@ def get_ip_port(modules_adresses: dict[int, str]):
}
return ip_port


class ClaudeProviders(Enum):
ANTHROPIC = "anthropic"
ANTHROPIC = "anthropic"
OPENROUTER = "openrouter"


@dataclass
class ValidationDataset:
prompt: str
Expand All @@ -142,12 +147,14 @@ class ValidationDataset:
chosen_subject: str
embedded_val_answer: list[float]


@dataclass
class ModuleInfo:
uid: int
address: list[str] #actually a tuple[str, str] but as a list
address: list[str] # actually a tuple[str, str] but as a list
key: Ss58Address


class TextValidator(Module):
"""A class for validating text data using a Synthia network.
Expand Down Expand Up @@ -213,7 +220,7 @@ def get_modules(self, client: CommuneClient, netuid: int) -> dict[int, str]:
return module_addreses

def _get_validation_dataset(self, settings: ValidatorSettings, size: int):

# TODO: make ValidatorSettings and the miners settings inherit from a
# common protocol
match self.provider:
Expand All @@ -229,7 +236,7 @@ def _get_validation_dataset(self, settings: ValidatorSettings, size: int):
claude_settings.max_tokens = settings.max_tokens
claude_settings.model = self.val_model
claude = OpenrouterModule(claude_settings)

ig = InputGenerator(claude)

retrier = retry(4, [Exception])
Expand All @@ -242,9 +249,12 @@ def _get_validation_dataset(self, settings: ValidatorSettings, size: int):
subject, val_answer = self._split_val_subject(explanations)
embedded_val_answer = self.embedder.get_embedding(val_answer)
val_dataset = ValidationDataset(
prompt=prompt, criteria=criteria, question_age=questions_age,
val_answer=val_answer, chosen_subject=subject,
embedded_val_answer=embedded_val_answer
prompt=prompt,
criteria=criteria,
question_age=questions_age,
val_answer=val_answer,
chosen_subject=subject,
embedded_val_answer=embedded_val_answer,
)
validation_list.append(val_dataset)
return validation_list
Expand All @@ -258,17 +268,14 @@ async def _get_miner_prediction(
module_ip, module_port = connection

question = get_miner_prompt(
val_info.criteria,
val_info.chosen_subject,
len(val_info.val_answer)
val_info.criteria, val_info.chosen_subject, len(val_info.val_answer)
)
client = ModuleClient(module_ip, int(module_port), self.key)
try:
miner_answer = await client.call(
"generate", miner_key,
{"prompt": question}, timeout=self.call_timeout
)

"generate", miner_key, {"prompt": question}, timeout=self.call_timeout
)

miner_answer = miner_answer["answer"]

except Exception as e:
Expand Down Expand Up @@ -341,13 +348,12 @@ async def validate_step(
syntia_netuid: The netuid of the Synthia subnet.
"""

self.client = CommuneClient(get_node_url())
modules_adresses = self.get_modules(self.client, syntia_netuid)
modules_keys = self.client.query_map_key(syntia_netuid)
val_ss58 = self.key.ss58_address
if val_ss58 not in modules_keys.values():
raise RuntimeError(
f"validator key {val_ss58} is not registered in subnet"
)
raise RuntimeError(f"validator key {val_ss58} is not registered in subnet")
modules_info: dict[int, ModuleInfo] = {}

modules_filtered_address = get_ip_port(modules_adresses)
Expand All @@ -364,16 +370,13 @@ async def validate_step(
hf_data_list: list[dict[str, str]] = []
# == Validation loop / Scoring ==
val_dataset = self._get_validation_dataset(settings, NUM_QUESTIONS_PER_CYCLE)



log(f"Selected the following miners: {modules_info.keys()}")
futures: list[asyncio.Task[tuple[str | None, ValidationDataset]]] = []
for mod_info in modules_info.values():
val_info = random.choice(val_dataset)
future = asyncio.create_task(
self._get_miner_prediction(
val_info, (mod_info.address, mod_info.key)
)
self._get_miner_prediction(val_info, (mod_info.address, mod_info.key))
)
futures.append(future)
miner_answers = await asyncio.gather(*futures)
Expand Down Expand Up @@ -455,4 +458,3 @@ def validation_loop(self, settings: ValidatorSettings | None = None) -> None:
sleep_time = settings.iteration_interval - elapsed
log(f"Sleeping for {sleep_time}")
time.sleep(sleep_time)

0 comments on commit 0e6a24a

Please sign in to comment.