diff --git a/src/cli.ts b/src/cli.ts index 2713d77..d544497 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -17,10 +17,17 @@ import { StdioServerTransport } from "./server/stdio.js"; import { ListResourcesResultSchema } from "./types.js"; async function runClient(url_or_command: string, args: string[]) { - const client = new Client({ - name: "mcp-typescript test client", - version: "0.1.0", - }); + const client = new Client( + { + name: "mcp-typescript test client", + version: "0.1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); let clientTransport; @@ -63,10 +70,15 @@ async function runServer(port: number | null) { console.log("Got new SSE connection"); const transport = new SSEServerTransport("/message", res); - const server = new Server({ - name: "mcp-typescript test server", - version: "0.1.0", - }); + const server = new Server( + { + name: "mcp-typescript test server", + version: "0.1.0", + }, + { + capabilities: {}, + }, + ); servers.push(server); @@ -97,10 +109,20 @@ async function runServer(port: number | null) { console.log(`Server running on http://localhost:${port}/sse`); }); } else { - const server = new Server({ - name: "mcp-typescript test server", - version: "0.1.0", - }); + const server = new Server( + { + name: "mcp-typescript test server", + version: "0.1.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + }, + ); const transport = new StdioServerTransport(); await server.connect(transport); diff --git a/src/client/index.test.ts b/src/client/index.test.ts index d93ca39..5610a62 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -3,11 +3,364 @@ /* eslint-disable @typescript-eslint/no-unused-expressions */ import { Client } from "./index.js"; import { z } from "zod"; -import { RequestSchema, NotificationSchema, ResultSchema } from "../types.js"; +import { + RequestSchema, + NotificationSchema, + ResultSchema, + LATEST_PROTOCOL_VERSION, + SUPPORTED_PROTOCOL_VERSIONS, + InitializeRequestSchema, + ListResourcesRequestSchema, + ListToolsRequestSchema, + CreateMessageRequestSchema, + ListRootsRequestSchema, +} from "../types.js"; +import { Transport } from "../shared/transport.js"; +import { Server } from "../server/index.js"; +import { InMemoryTransport } from "../inMemory.js"; + +test("should initialize with matching protocol version", async () => { + const clientTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.method === "initialize") { + clientTransport.onmessage?.({ + jsonrpc: "2.0", + id: message.id, + result: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + serverInfo: { + name: "test", + version: "1.0", + }, + }, + }); + } + return Promise.resolve(); + }), + }; + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + await client.connect(clientTransport); + + // Should have sent initialize with latest version + expect(clientTransport.send).toHaveBeenCalledWith( + expect.objectContaining({ + method: "initialize", + params: expect.objectContaining({ + protocolVersion: LATEST_PROTOCOL_VERSION, + }), + }), + ); +}); + +test("should initialize with supported older protocol version", async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const clientTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.method === "initialize") { + clientTransport.onmessage?.({ + jsonrpc: "2.0", + id: message.id, + result: { + protocolVersion: OLD_VERSION, + capabilities: {}, + serverInfo: { + name: "test", + version: "1.0", + }, + }, + }); + } + return Promise.resolve(); + }), + }; + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + await client.connect(clientTransport); + + // Connection should succeed with the older version + expect(client.getServerVersion()).toEqual({ + name: "test", + version: "1.0", + }); +}); + +test("should reject unsupported protocol version", async () => { + const clientTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.method === "initialize") { + clientTransport.onmessage?.({ + jsonrpc: "2.0", + id: message.id, + result: { + protocolVersion: "invalid-version", + capabilities: {}, + serverInfo: { + name: "test", + version: "1.0", + }, + }, + }); + } + return Promise.resolve(); + }), + }; + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + await expect(client.connect(clientTransport)).rejects.toThrow( + "Server's protocol version is not supported: invalid-version", + ); + + expect(clientTransport.close).toHaveBeenCalled(); +}); + +test("should respect server capabilities", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + resources: {}, + tools: {}, + }, + }, + ); + + server.setRequestHandler(InitializeRequestSchema, (_request) => ({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: { + resources: {}, + tools: {}, + }, + serverInfo: { + name: "test", + version: "1.0", + }, + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [], + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [], + })); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + enforceStrictCapabilities: true, + }, + ); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // Server supports resources and tools, but not prompts + expect(client.getServerCapabilities()).toEqual({ + resources: {}, + tools: {}, + }); + + // These should work + await expect(client.listResources()).resolves.not.toThrow(); + await expect(client.listTools()).resolves.not.toThrow(); + + // This should throw because prompts are not supported + await expect(client.listPrompts()).rejects.toThrow( + "Server does not support prompts", + ); +}); + +test("should respect client notification capabilities", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: {}, + }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + roots: { + listChanged: true, + }, + }, + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // This should work because the client has the roots.listChanged capability + await expect(client.sendRootsListChanged()).resolves.not.toThrow(); + + // Create a new client without the roots.listChanged capability + const clientWithoutCapability = new Client( + { + name: "test client without capability", + version: "1.0", + }, + { + capabilities: {}, + enforceStrictCapabilities: true, + }, + ); + + await clientWithoutCapability.connect(clientTransport); + + // This should throw because the client doesn't have the roots.listChanged capability + await expect(clientWithoutCapability.sendRootsListChanged()).rejects.toThrow( + /^Client does not support/, + ); +}); + +test("should respect server notification capabilities", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + logging: {}, + resources: { + listChanged: true, + }, + }, + }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: {}, + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + // These should work because the server has the corresponding capabilities + await expect( + server.sendLoggingMessage({ level: "info", data: "Test" }), + ).resolves.not.toThrow(); + await expect(server.sendResourceListChanged()).resolves.not.toThrow(); + + // This should throw because the server doesn't have the tools capability + await expect(server.sendToolListChanged()).rejects.toThrow( + "Server does not support notifying of tool list changes", + ); +}); + +test("should only allow setRequestHandler for declared capabilities", () => { + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + // This should work because sampling is a declared capability + expect(() => { + client.setRequestHandler(CreateMessageRequestSchema, () => ({ + model: "test-model", + role: "assistant", + content: { + type: "text", + text: "Test response", + }, + })); + }).not.toThrow(); + + // This should throw because roots listing is not a declared capability + expect(() => { + client.setRequestHandler(ListRootsRequestSchema, () => ({})); + }).toThrow("Client does not support roots capability"); +}); /* -Test that custom request/notification/result schemas can be used with the Client class. -*/ + Test that custom request/notification/result schemas can be used with the Client class. + */ test("should typecheck", () => { const GetWeatherRequestSchema = RequestSchema.extend({ method: z.literal("weather/get"), @@ -50,10 +403,17 @@ test("should typecheck", () => { WeatherRequest, WeatherNotification, WeatherResult - >({ - name: "WeatherClient", - version: "1.0.0", - }); + >( + { + name: "WeatherClient", + version: "1.0.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); // Typecheck that only valid weather requests/notifications/results are allowed false && diff --git a/src/client/index.ts b/src/client/index.ts index c3662b2..e0df322 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -1,8 +1,13 @@ -import { ProgressCallback, Protocol } from "../shared/protocol.js"; +import { + ProgressCallback, + Protocol, + ProtocolOptions, +} from "../shared/protocol.js"; import { Transport } from "../shared/transport.js"; import { CallToolRequest, CallToolResultSchema, + ClientCapabilities, ClientNotification, ClientRequest, ClientResult, @@ -32,9 +37,16 @@ import { ServerCapabilities, SubscribeRequest, SUPPORTED_PROTOCOL_VERSIONS, - UnsubscribeRequest + UnsubscribeRequest, } from "../types.js"; +export type ClientOptions = ProtocolOptions & { + /** + * Capabilities to advertise as being supported by this client. + */ + capabilities: ClientCapabilities; +}; + /** * An MCP client on top of a pluggable transport. * @@ -71,45 +83,67 @@ export class Client< > { private _serverCapabilities?: ServerCapabilities; private _serverVersion?: Implementation; + private _capabilities: ClientCapabilities; /** * Initializes this client with the given name and version information. */ - constructor(private _clientInfo: Implementation) { - super(); + constructor( + private _clientInfo: Implementation, + options: ClientOptions, + ) { + super(options); + this._capabilities = options.capabilities; + } + + protected assertCapability( + capability: keyof ServerCapabilities, + method: string, + ): void { + if (!this._serverCapabilities?.[capability]) { + throw new Error( + `Server does not support ${capability} (required for ${method})`, + ); + } } override async connect(transport: Transport): Promise { await super.connect(transport); - const result = await this.request( - { - method: "initialize", - params: { - protocolVersion: LATEST_PROTOCOL_VERSION, - capabilities: {}, - clientInfo: this._clientInfo, + try { + const result = await this.request( + { + method: "initialize", + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: this._capabilities, + clientInfo: this._clientInfo, + }, }, - }, - InitializeResultSchema, - ); + InitializeResultSchema, + ); - if (result === undefined) { - throw new Error(`Server sent invalid initialize result: ${result}`); - } + if (result === undefined) { + throw new Error(`Server sent invalid initialize result: ${result}`); + } - if (!SUPPORTED_PROTOCOL_VERSIONS.includes(result.protocolVersion)) { - throw new Error( - `Server's protocol version is not supported: ${result.protocolVersion}`, - ); - } + if (!SUPPORTED_PROTOCOL_VERSIONS.includes(result.protocolVersion)) { + throw new Error( + `Server's protocol version is not supported: ${result.protocolVersion}`, + ); + } - this._serverCapabilities = result.capabilities; - this._serverVersion = result.serverInfo; + this._serverCapabilities = result.capabilities; + this._serverVersion = result.serverInfo; - await this.notification({ - method: "notifications/initialized", - }); + await this.notification({ + method: "notifications/initialized", + }); + } catch (error) { + // Disconnect if initialization fails. + void this.close(); + throw error; + } } /** @@ -126,6 +160,120 @@ export class Client< return this._serverVersion; } + protected assertCapabilityForMethod(method: RequestT["method"]): void { + switch (method as ClientRequest["method"]) { + case "logging/setLevel": + if (!this._serverCapabilities?.logging) { + throw new Error( + `Server does not support logging (required for ${method})`, + ); + } + break; + + case "prompts/get": + case "prompts/list": + if (!this._serverCapabilities?.prompts) { + throw new Error( + `Server does not support prompts (required for ${method})`, + ); + } + break; + + case "resources/list": + case "resources/templates/list": + case "resources/read": + case "resources/subscribe": + case "resources/unsubscribe": + if (!this._serverCapabilities?.resources) { + throw new Error( + `Server does not support resources (required for ${method})`, + ); + } + + if ( + method === "resources/subscribe" && + !this._serverCapabilities.resources.subscribe + ) { + throw new Error( + `Server does not support resource subscriptions (required for ${method})`, + ); + } + + break; + + case "tools/call": + case "tools/list": + if (!this._serverCapabilities?.tools) { + throw new Error( + `Server does not support tools (required for ${method})`, + ); + } + break; + + case "completion/complete": + if (!this._serverCapabilities?.prompts) { + throw new Error( + `Server does not support prompts (required for ${method})`, + ); + } + break; + + case "initialize": + // No specific capability required for initialize + break; + + case "ping": + // No specific capability required for ping + break; + } + } + + protected assertNotificationCapability( + method: NotificationT["method"], + ): void { + switch (method as ClientNotification["method"]) { + case "notifications/roots/list_changed": + if (!this._capabilities.roots?.listChanged) { + throw new Error( + `Client does not support roots list changed notifications (required for ${method})`, + ); + } + break; + + case "notifications/initialized": + // No specific capability required for initialized + break; + + case "notifications/progress": + // Progress notifications are always allowed + break; + } + } + + protected assertRequestHandlerCapability(method: string): void { + switch (method) { + case "sampling/createMessage": + if (!this._capabilities.sampling) { + throw new Error( + `Client does not support sampling capability (required for ${method})`, + ); + } + break; + + case "roots/list": + if (!this._capabilities.roots) { + throw new Error( + `Client does not support roots capability (required for ${method})`, + ); + } + break; + + case "ping": + // No specific capability required for ping + break; + } + } + async ping() { return this.request({ method: "ping" }, EmptyResultSchema); } @@ -219,7 +367,9 @@ export class Client< async callTool( params: CallToolRequest["params"], - resultSchema: typeof CallToolResultSchema | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, + resultSchema: + | typeof CallToolResultSchema + | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, onprogress?: ProgressCallback, ) { return this.request( diff --git a/src/inMemory.test.ts b/src/inMemory.test.ts new file mode 100644 index 0000000..f7e9e97 --- /dev/null +++ b/src/inMemory.test.ts @@ -0,0 +1,94 @@ +import { InMemoryTransport } from "./inMemory.js"; +import { JSONRPCMessage } from "./types.js"; + +describe("InMemoryTransport", () => { + let clientTransport: InMemoryTransport; + let serverTransport: InMemoryTransport; + + beforeEach(() => { + [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + }); + + test("should create linked pair", () => { + expect(clientTransport).toBeDefined(); + expect(serverTransport).toBeDefined(); + }); + + test("should start without error", async () => { + await expect(clientTransport.start()).resolves.not.toThrow(); + await expect(serverTransport.start()).resolves.not.toThrow(); + }); + + test("should send message from client to server", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + id: 1, + }; + + let receivedMessage: JSONRPCMessage | undefined; + serverTransport.onmessage = (msg) => { + receivedMessage = msg; + }; + + await clientTransport.send(message); + expect(receivedMessage).toEqual(message); + }); + + test("should send message from server to client", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + id: 1, + }; + + let receivedMessage: JSONRPCMessage | undefined; + clientTransport.onmessage = (msg) => { + receivedMessage = msg; + }; + + await serverTransport.send(message); + expect(receivedMessage).toEqual(message); + }); + + test("should handle close", async () => { + let clientClosed = false; + let serverClosed = false; + + clientTransport.onclose = () => { + clientClosed = true; + }; + + serverTransport.onclose = () => { + serverClosed = true; + }; + + await clientTransport.close(); + expect(clientClosed).toBe(true); + expect(serverClosed).toBe(true); + }); + + test("should throw error when sending after close", async () => { + await clientTransport.close(); + await expect( + clientTransport.send({ jsonrpc: "2.0", method: "test", id: 1 }), + ).rejects.toThrow("Not connected"); + }); + + test("should queue messages sent before start", async () => { + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + id: 1, + }; + + let receivedMessage: JSONRPCMessage | undefined; + serverTransport.onmessage = (msg) => { + receivedMessage = msg; + }; + + await clientTransport.send(message); + await serverTransport.start(); + expect(receivedMessage).toEqual(message); + }); +}); diff --git a/src/inMemory.ts b/src/inMemory.ts new file mode 100644 index 0000000..2763f38 --- /dev/null +++ b/src/inMemory.ts @@ -0,0 +1,54 @@ +import { Transport } from "./shared/transport.js"; +import { JSONRPCMessage } from "./types.js"; + +/** + * In-memory transport for creating clients and servers that talk to each other within the same process. + */ +export class InMemoryTransport implements Transport { + private _otherTransport?: InMemoryTransport; + private _messageQueue: JSONRPCMessage[] = []; + + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage) => void; + + /** + * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server. + */ + static createLinkedPair(): [InMemoryTransport, InMemoryTransport] { + const clientTransport = new InMemoryTransport(); + const serverTransport = new InMemoryTransport(); + clientTransport._otherTransport = serverTransport; + serverTransport._otherTransport = clientTransport; + return [clientTransport, serverTransport]; + } + + async start(): Promise { + // Process any messages that were queued before start was called + while (this._messageQueue.length > 0) { + const message = this._messageQueue.shift(); + if (message) { + this.onmessage?.(message); + } + } + } + + async close(): Promise { + const other = this._otherTransport; + this._otherTransport = undefined; + await other?.close(); + this.onclose?.(); + } + + async send(message: JSONRPCMessage): Promise { + if (!this._otherTransport) { + throw new Error("Not connected"); + } + + if (this._otherTransport.onmessage) { + this._otherTransport.onmessage(message); + } else { + this._otherTransport._messageQueue.push(message); + } + } +} diff --git a/src/server/index.test.ts b/src/server/index.test.ts index be33c58..d30c670 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -3,11 +3,338 @@ /* eslint-disable @typescript-eslint/no-unused-expressions */ import { Server } from "./index.js"; import { z } from "zod"; -import { RequestSchema, NotificationSchema, ResultSchema } from "../types.js"; +import { + RequestSchema, + NotificationSchema, + ResultSchema, + LATEST_PROTOCOL_VERSION, + SUPPORTED_PROTOCOL_VERSIONS, + CreateMessageRequestSchema, + ListPromptsRequestSchema, + ListResourcesRequestSchema, + ListToolsRequestSchema, + SetLevelRequestSchema, +} from "../types.js"; +import { Transport } from "../shared/transport.js"; +import { InMemoryTransport } from "../inMemory.js"; +import { Client } from "../client/index.js"; + +test("should accept latest protocol version", async () => { + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise((resolve) => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: "test server", + version: "1.0", + }, + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }), + }; + + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + }, + ); + + await server.connect(serverTransport); + + // Simulate initialize request with latest version + serverTransport.onmessage?.({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + clientInfo: { + name: "test client", + version: "1.0", + }, + }, + }); + + await expect(sendPromise).resolves.toBeUndefined(); +}); + +test("should accept supported older protocol version", async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise((resolve) => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: OLD_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: "test server", + version: "1.0", + }, + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }), + }; + + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + }, + ); + + await server.connect(serverTransport); + + // Simulate initialize request with older version + serverTransport.onmessage?.({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: OLD_VERSION, + capabilities: {}, + clientInfo: { + name: "test client", + version: "1.0", + }, + }, + }); + + await expect(sendPromise).resolves.toBeUndefined(); +}); + +test("should handle unsupported protocol version", async () => { + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise((resolve) => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: jest.fn().mockResolvedValue(undefined), + close: jest.fn().mockResolvedValue(undefined), + send: jest.fn().mockImplementation((message) => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: "test server", + version: "1.0", + }, + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }), + }; + + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + }, + ); + + await server.connect(serverTransport); + + // Simulate initialize request with unsupported version + serverTransport.onmessage?.({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: "invalid-version", + capabilities: {}, + clientInfo: { + name: "test client", + version: "1.0", + }, + }, + }); + + await expect(sendPromise).resolves.toBeUndefined(); +}); + +test("should respect client capabilities", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + enforceStrictCapabilities: true, + }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + { + capabilities: { + sampling: {}, + }, + }, + ); + + // Implement request handler for sampling/createMessage + client.setRequestHandler(CreateMessageRequestSchema, async (request) => { + // Mock implementation of createMessage + return { + model: "test-model", + role: "assistant", + content: { + type: "text", + text: "This is a test response", + }, + }; + }); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + expect(server.getClientCapabilities()).toEqual({ sampling: {} }); + + // This should work because sampling is supported by the client + await expect( + server.createMessage({ + messages: [], + maxTokens: 10, + }), + ).resolves.not.toThrow(); + + // This should still throw because roots are not supported by the client + await expect(server.listRoots()).rejects.toThrow(/^Client does not support/); +}); + +test("should respect server notification capabilities", async () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + logging: {}, + }, + enforceStrictCapabilities: true, + }, + ); + + const [clientTransport, serverTransport] = + InMemoryTransport.createLinkedPair(); + + await server.connect(serverTransport); + + // This should work because logging is supported by the server + await expect( + server.sendLoggingMessage({ + level: "info", + data: "Test log message", + }), + ).resolves.not.toThrow(); + + // This should throw because resource notificaitons are not supported by the server + await expect( + server.sendResourceUpdated({ uri: "test://resource" }), + ).rejects.toThrow(/^Server does not support/); +}); + +test("should only allow setRequestHandler for declared capabilities", () => { + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + }, + }, + ); + + // These should work because the capabilities are declared + expect(() => { + server.setRequestHandler(ListPromptsRequestSchema, () => ({ prompts: [] })); + }).not.toThrow(); + + expect(() => { + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [], + })); + }).not.toThrow(); + + // These should throw because the capabilities are not declared + expect(() => { + server.setRequestHandler(ListToolsRequestSchema, () => ({ tools: [] })); + }).toThrow(/^Server does not support tools/); + + expect(() => { + server.setRequestHandler(SetLevelRequestSchema, () => ({})); + }).toThrow(/^Server does not support logging/); +}); /* -Test that custom request/notification/result schemas can be used with the Server class. -*/ + Test that custom request/notification/result schemas can be used with the Server class. + */ test("should typecheck", () => { const GetWeatherRequestSchema = RequestSchema.extend({ method: z.literal("weather/get"), @@ -50,10 +377,20 @@ test("should typecheck", () => { WeatherRequest, WeatherNotification, WeatherResult - >({ - name: "WeatherServer", - version: "1.0.0", - }); + >( + { + name: "WeatherServer", + version: "1.0.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + }, + ); // Typecheck that only valid weather requests/notifications/results are allowed weatherServer.setRequestHandler(GetWeatherRequestSchema, (request) => { diff --git a/src/server/index.ts b/src/server/index.ts index 8cf2a91..ecb525b 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -1,4 +1,8 @@ -import { ProgressCallback, Protocol } from "../shared/protocol.js"; +import { + ProgressCallback, + Protocol, + ProtocolOptions, +} from "../shared/protocol.js"; import { ClientCapabilities, CreateMessageRequest, @@ -10,11 +14,8 @@ import { InitializeRequestSchema, InitializeResult, LATEST_PROTOCOL_VERSION, - ListPromptsRequestSchema, - ListResourcesRequestSchema, ListRootsRequest, ListRootsResultSchema, - ListToolsRequestSchema, LoggingMessageNotification, Notification, Request, @@ -24,10 +25,16 @@ import { ServerNotification, ServerRequest, ServerResult, - SetLevelRequestSchema, - SUPPORTED_PROTOCOL_VERSIONS + SUPPORTED_PROTOCOL_VERSIONS, } from "../types.js"; +export type ServerOptions = ProtocolOptions & { + /** + * Capabilities to advertise as being supported by this server. + */ + capabilities: ServerCapabilities; +}; + /** * An MCP server on top of a pluggable transport. * @@ -64,6 +71,7 @@ export class Server< > { private _clientCapabilities?: ClientCapabilities; private _clientVersion?: Implementation; + private _capabilities: ServerCapabilities; /** * Callback for when initialization has fully completed (i.e., the client has sent an `initialized` notification). @@ -73,8 +81,12 @@ export class Server< /** * Initializes this server with the given name and version information. */ - constructor(private _serverInfo: Implementation) { - super(); + constructor( + private _serverInfo: Implementation, + options: ServerOptions, + ) { + super(options); + this._capabilities = options.capabilities; this.setRequestHandler(InitializeRequestSchema, (request) => this._oninitialize(request), @@ -84,6 +96,126 @@ export class Server< ); } + protected assertCapabilityForMethod(method: RequestT["method"]): void { + switch (method as ServerRequest["method"]) { + case "sampling/createMessage": + if (!this._clientCapabilities?.sampling) { + throw new Error( + `Client does not support sampling (required for ${method})`, + ); + } + break; + + case "roots/list": + if (!this._clientCapabilities?.roots) { + throw new Error( + `Client does not support listing roots (required for ${method})`, + ); + } + break; + + case "ping": + // No specific capability required for ping + break; + } + } + + protected assertNotificationCapability( + method: (ServerNotification | NotificationT)["method"], + ): void { + switch (method as ServerNotification["method"]) { + case "notifications/message": + if (!this._capabilities.logging) { + throw new Error( + `Server does not support logging (required for ${method})`, + ); + } + break; + + case "notifications/resources/updated": + case "notifications/resources/list_changed": + if (!this._capabilities.resources) { + throw new Error( + `Server does not support notifying about resources (required for ${method})`, + ); + } + break; + + case "notifications/tools/list_changed": + if (!this._capabilities.tools) { + throw new Error( + `Server does not support notifying of tool list changes (required for ${method})`, + ); + } + break; + + case "notifications/prompts/list_changed": + if (!this._capabilities.prompts) { + throw new Error( + `Server does not support notifying of prompt list changes (required for ${method})`, + ); + } + break; + + case "notifications/progress": + // Progress notifications are always allowed + break; + } + } + + protected assertRequestHandlerCapability(method: string): void { + switch (method) { + case "sampling/createMessage": + if (!this._capabilities.sampling) { + throw new Error( + `Server does not support sampling (required for ${method})`, + ); + } + break; + + case "logging/setLevel": + if (!this._capabilities.logging) { + throw new Error( + `Server does not support logging (required for ${method})`, + ); + } + break; + + case "prompts/get": + case "prompts/list": + if (!this._capabilities.prompts) { + throw new Error( + `Server does not support prompts (required for ${method})`, + ); + } + break; + + case "resources/list": + case "resources/templates/list": + case "resources/read": + if (!this._capabilities.resources) { + throw new Error( + `Server does not support resources (required for ${method})`, + ); + } + break; + + case "tools/call": + case "tools/list": + if (!this._capabilities.tools) { + throw new Error( + `Server does not support tools (required for ${method})`, + ); + } + break; + + case "ping": + case "initialize": + // No specific capability required for these methods + break; + } + } + private async _oninitialize( request: InitializeRequest, ): Promise { @@ -93,7 +225,9 @@ export class Server< this._clientVersion = request.params.clientInfo; return { - protocolVersion: SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion) ? requestedVersion : LATEST_PROTOCOL_VERSION, + protocolVersion: SUPPORTED_PROTOCOL_VERSIONS.includes(requestedVersion) + ? requestedVersion + : LATEST_PROTOCOL_VERSION, capabilities: this.getCapabilities(), serverInfo: this._serverInfo, }; @@ -114,28 +248,7 @@ export class Server< } private getCapabilities(): ServerCapabilities { - return { - prompts: this._requestHandlers.has( - ListPromptsRequestSchema.shape.method.value as string, - ) - ? {} - : undefined, - resources: this._requestHandlers.has( - ListResourcesRequestSchema.shape.method.value as string, - ) - ? {} - : undefined, - tools: this._requestHandlers.has( - ListToolsRequestSchema.shape.method.value as string, - ) - ? {} - : undefined, - logging: this._requestHandlers.has( - SetLevelRequestSchema.shape.method.value as string, - ) - ? {} - : undefined, - }; + return this._capabilities; } async ping() { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 22d2503..85610a9 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -21,11 +21,25 @@ import { Transport } from "./transport.js"; */ export type ProgressCallback = (progress: Progress) => void; +/** + * Additional initialization options. + */ +export type ProtocolOptions = { + /** + * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities. + * + * Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those. + * + * Currently this defaults to false, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to true. + */ + enforceStrictCapabilities?: boolean; +}; + /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. */ -export class Protocol< +export abstract class Protocol< SendRequestT extends Request, SendNotificationT extends Notification, SendResultT extends Result, @@ -70,7 +84,7 @@ export class Protocol< */ fallbackNotificationHandler?: (notification: Notification) => Promise; - constructor() { + constructor(private _options?: ProtocolOptions) { this.setNotificationHandler(ProgressNotificationSchema, (notification) => { this._onprogress(notification as unknown as ProgressNotification); }); @@ -245,6 +259,31 @@ export class Protocol< await this._transport?.close(); } + /** + * A method to check if a capability is supported by the remote side, for the given method to be called. + * + * This should be implemented by subclasses. + */ + protected abstract assertCapabilityForMethod( + method: SendRequestT["method"], + ): void; + + /** + * A method to check if a notification is supported by the local side, for the given method to be sent. + * + * This should be implemented by subclasses. + */ + protected abstract assertNotificationCapability( + method: SendNotificationT["method"], + ): void; + + /** + * A method to check if a request handler is supported by the local side, for the given method to be handled. + * + * This should be implemented by subclasses. + */ + protected abstract assertRequestHandlerCapability(method: string): void; + /** * Sends a request and wait for a response, with optional progress notifications in the meantime (if supported by the server). * @@ -261,6 +300,10 @@ export class Protocol< return; } + if (this._options?.enforceStrictCapabilities === true) { + this.assertCapabilityForMethod(request.method); + } + const messageId = this._requestMessageId++; const jsonrpcRequest: JSONRPCRequest = { ...request, @@ -301,6 +344,8 @@ export class Protocol< throw new Error("Not connected"); } + this.assertNotificationCapability(notification.method); + const jsonrpcNotification: JSONRPCNotification = { ...notification, jsonrpc: "2.0", @@ -322,7 +367,9 @@ export class Protocol< requestSchema: T, handler: (request: z.infer) => SendResultT | Promise, ): void { - this._requestHandlers.set(requestSchema.shape.method.value, (request) => + const method = requestSchema.shape.method.value; + this.assertRequestHandlerCapability(method); + this._requestHandlers.set(method, (request) => Promise.resolve(handler(requestSchema.parse(request))), ); } diff --git a/src/types.ts b/src/types.ts index 0ba6fa2..a0d2d80 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,7 +1,10 @@ import { z } from "zod"; export const LATEST_PROTOCOL_VERSION = "2024-11-05"; -export const SUPPORTED_PROTOCOL_VERSIONS = [LATEST_PROTOCOL_VERSION, "2024-10-07"]; +export const SUPPORTED_PROTOCOL_VERSIONS = [ + LATEST_PROTOCOL_VERSION, + "2024-10-07", +]; /* JSON-RPC types */ export const JSONRPC_VERSION = "2.0"; @@ -179,7 +182,7 @@ export const ClientCapabilitiesSchema = z z .object({ /** - * Whether the client supports notifications for changes to the roots list. + * Whether the client supports issuing notifications for changes to the roots list. */ listChanged: z.optional(z.boolean()), }) @@ -223,7 +226,7 @@ export const ServerCapabilitiesSchema = z z .object({ /** - * Whether this server supports notifications for changes to the prompt list. + * Whether this server supports issuing notifications for changes to the prompt list. */ listChanged: z.optional(z.boolean()), }) @@ -236,11 +239,12 @@ export const ServerCapabilitiesSchema = z z .object({ /** - * Whether this server supports subscribing to resource updates. + * Whether this server supports clients subscribing to resource updates. */ subscribe: z.optional(z.boolean()), + /** - * Whether this server supports notifications for changes to the resource list. + * Whether this server supports issuing notifications for changes to the resource list. */ listChanged: z.optional(z.boolean()), }) @@ -253,7 +257,7 @@ export const ServerCapabilitiesSchema = z z .object({ /** - * Whether this server supports notifications for changes to the tool list. + * Whether this server supports issuing notifications for changes to the tool list. */ listChanged: z.optional(z.boolean()), }) @@ -725,9 +729,11 @@ export const CallToolResultSchema = ResultSchema.extend({ /** * CallToolResultSchema extended with backwards compatibility to protocol version 2024-10-07. */ -export const CompatibilityCallToolResultSchema = CallToolResultSchema.or(ResultSchema.extend({ - toolResult: z.unknown(), -})); +export const CompatibilityCallToolResultSchema = CallToolResultSchema.or( + ResultSchema.extend({ + toolResult: z.unknown(), + }), +); /** * Used by the client to invoke a tool provided by the server. @@ -802,12 +808,14 @@ export const LoggingMessageNotificationSchema = NotificationSchema.extend({ /** * Hints to use for model selection. */ -export const ModelHintSchema = z.object({ - /** - * A hint for a model name. - */ - name: z.string().optional(), -}).passthrough(); +export const ModelHintSchema = z + .object({ + /** + * A hint for a model name. + */ + name: z.string().optional(), + }) + .passthrough(); /** * The server's preferences for model selection, requested of the client during sampling. @@ -886,7 +894,9 @@ export const CreateMessageResultSchema = ResultSchema.extend({ /** * The reason why sampling stopped. */ - stopReason: z.optional(z.enum(["endTurn", "stopSequence", "maxTokens"]).or(z.string())), + stopReason: z.optional( + z.enum(["endTurn", "stopSequence", "maxTokens"]).or(z.string()), + ), role: z.enum(["user", "assistant"]), content: z.discriminatedUnion("type", [ TextContentSchema, @@ -1156,7 +1166,9 @@ export type Tool = z.infer; export type ListToolsRequest = z.infer; export type ListToolsResult = z.infer; export type CallToolResult = z.infer; -export type CompatibilityCallToolResult = z.infer; +export type CompatibilityCallToolResult = z.infer< + typeof CompatibilityCallToolResultSchema +>; export type CallToolRequest = z.infer; export type ToolListChangedNotification = z.infer< typeof ToolListChangedNotificationSchema