From 3d36aa035ad4b07a5a24f9b20cb3e50e433fad5d Mon Sep 17 00:00:00 2001 From: OYX-1 <74037789+OYX-1@users.noreply.github.com> Date: Thu, 15 Aug 2024 20:44:24 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20TensorFlow=E6=A1=86=E6=9E=B6?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0psNode=E5=92=8CworkerNode=E5=8F=82=E6=95=B0?= =?UTF-8?q?=20(#1398)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### 1.TensorFlow框架增加psNode和workerNode参数 ![image](https://github.com/user-attachments/assets/3dc831ef-cde1-4b46-be44-ab19710a3281) ### 2.修改了分布式训练框架的位置,这样从单节点增加到多节点时,不用再去表单上面选择框架 ![image](https://github.com/user-attachments/assets/71efca5c-3164-449b-b9c9-6aba8e81b745) ### 3.因增加psNode和workerNode参数,相应地修改了再次提交作业 --- .changeset/tidy-pianos-reply.md | 5 + .../(auth)/jobs/[clusterId]/LaunchAppForm.tsx | 96 ++++++++++++++++--- apps/ai/src/server/trpc/route/jobs/jobs.ts | 7 +- 3 files changed, 92 insertions(+), 16 deletions(-) create mode 100644 .changeset/tidy-pianos-reply.md diff --git a/.changeset/tidy-pianos-reply.md b/.changeset/tidy-pianos-reply.md new file mode 100644 index 0000000000..f22248060d --- /dev/null +++ b/.changeset/tidy-pianos-reply.md @@ -0,0 +1,5 @@ +--- +"@scow/ai": patch +--- + +TensorFlow 增加 psNode 和 workerNode 参数 diff --git a/apps/ai/src/app/(auth)/jobs/[clusterId]/LaunchAppForm.tsx b/apps/ai/src/app/(auth)/jobs/[clusterId]/LaunchAppForm.tsx index ea19affa7b..37d0b23e83 100644 --- a/apps/ai/src/app/(auth)/jobs/[clusterId]/LaunchAppForm.tsx +++ b/apps/ai/src/app/(auth)/jobs/[clusterId]/LaunchAppForm.tsx @@ -73,6 +73,9 @@ interface FixedFormFields { account: string; maxTime: number; command?: string; + // TensorFlow特有参数 + psNodes?: number; + workerNodes?: number; } interface CustomFormFields { @@ -243,6 +246,8 @@ export const LaunchAppForm = (props: Props) => { const gpuCount = Form.useWatch("gpuCount", form)!; + const framework = Form.useWatch("framework", form); + const memorySize = (currentPartitionInfo ? currentPartitionInfo.gpus ? nodeCount * gpuCount * Math.floor(currentPartitionInfo.cores / currentPartitionInfo.gpus) @@ -478,6 +483,8 @@ export const LaunchAppForm = (props: Props) => { const customAttributes = "customAttributes" in inputParams ? inputParams.customAttributes : {}; const command = "command" in inputParams ? inputParams.command : undefined; const framework = "framework" in inputParams ? inputParams.framework : undefined; + const psNodes = "psNodes" in inputParams ? inputParams.psNodes : undefined; + const workerNodes = "workerNodes" in inputParams ? inputParams.workerNodes : undefined; form.setFieldsValue({ mountPoints, customFields: { @@ -493,6 +500,8 @@ export const LaunchAppForm = (props: Props) => { maxTime, appJobName: genAppJobName(appName ?? "trainJobs"), command, + psNodes, + workerNodes, }); } @@ -518,17 +527,26 @@ export const LaunchAppForm = (props: Props) => { }, }); + const handleFormChange = (changedValues: Partial, allValues: FormFields) => { + const { psNodes, workerNodes } = allValues; + if ("psNodes" in changedValues || "workerNodes" in changedValues) { + const newTotal = (psNodes || 0) + (workerNodes || 0); + form.setFieldsValue({ nodeCount: newTotal }); + } + }; + return (
{ const { appJobName, algorithm, dataset, image, remoteImageUrl, framework, startCommand, model, mountPoints, account, partition, coreCount, - gpuCount, maxTime, command, customFields } = await form.validateFields(); + gpuCount, maxTime, command, customFields, psNodes, workerNodes } = await form.validateFields(); if (isTraining) { await trainJobMutation.mutateAsync({ @@ -555,6 +573,8 @@ export const LaunchAppForm = (props: Props) => { memory: memorySize, command: command || "", gpuType: currentPartitionInfo!.gpuType, + psNodes, + workerNodes, }); } else { let workingDirectory: string | undefined; @@ -698,20 +718,6 @@ export const LaunchAppForm = (props: Props) => { > - {/* 分布式训练或者华为的卡训练,需要指定训练框架 */} - {(isTraining && (nodeCount > 1 || currentPartitionInfo?.gpuType === "huawei.com/Ascend910")) ? ( - <> - {/* 手动选择算法框架,下拉框只有 tensorflow, pytorch */} - - - - - ) : null} {(customImage && !isTraining) ? ( { min={1} max={isTraining ? undefined : 1} {...inputNumberFloorConfig} + // framework是tensorflow且不是华为卡时 不允许手动改 + disabled={isTraining && framework === "tensorflow" + && (currentPartitionInfo ? currentPartitionInfo.gpuType !== "huawei.com/Ascend910" : true)} /> + {/* tensorflow训练框架时,除了huawei.com/Ascend910的卡之外,都要区分PS node 和worker node */} + {(isTraining && framework === "tensorflow" + && (currentPartitionInfo ? currentPartitionInfo.gpuType !== "huawei.com/Ascend910" : true)) ? ( + <> + + + + + ) : null} + {(isTraining && framework === "tensorflow" + && (currentPartitionInfo ? currentPartitionInfo.gpuType !== "huawei.com/Ascend910" : true)) ? ( + <> + + + + + ) : null} { currentPartitionInfo?.gpus ? ( { ) } + {/* 分布式训练或者华为的卡训练,需要指定训练框架 */} + {(isTraining && (nodeCount > 1 || currentPartitionInfo?.gpuType === "huawei.com/Ascend910")) ? ( + <> + {/* 手动选择算法框架,下拉框只有 tensorflow, pytorch */} + + + + + ) : null} diff --git a/apps/ai/src/server/trpc/route/jobs/jobs.ts b/apps/ai/src/server/trpc/route/jobs/jobs.ts index 7ad22db527..04ab621d36 100644 --- a/apps/ai/src/server/trpc/route/jobs/jobs.ts +++ b/apps/ai/src/server/trpc/route/jobs/jobs.ts @@ -88,6 +88,9 @@ const TrainJobInputSchema = z.object({ maxTime: z.number(), command: z.string(), gpuType: z.string().optional(), + // TensorFlow特有参数 + psNodes: z.number().optional(), + workerNodes: z.number().optional(), }); export type TrainJobInput = z.infer; @@ -134,7 +137,7 @@ procedure async ({ input, ctx: { user } }) => { const { clusterId, trainJobName, isAlgorithmPrivate, algorithm, image, framework, remoteImageUrl, isDatasetPrivate, dataset, isModelPrivate, model, mountPoints = [], account, partition, - coreCount, nodeCount, gpuCount, memory, maxTime, command, gpuType } = input; + coreCount, nodeCount, gpuCount, memory, maxTime, command, gpuType, psNodes, workerNodes } = input; const userId = user.identityId; const host = getClusterLoginNode(clusterId); @@ -253,6 +256,8 @@ procedure // 如果nodeCount不为1但同时选定镜像又没有框架标签,该接口会报错 (nodeCount === 1 && !gpuType?.startsWith("huawei.com")) ? "" : framework || "", ], + psNodeCount:psNodes, + workerNodeCount:workerNodes, }).catch((e) => { const ex = e as ServiceError; throw new TRPCError({