Skip to content

Commit

Permalink
feat(ai): TensorFlow框架增加psNode和workerNode参数 (#1398)
Browse files Browse the repository at this point in the history
### 1.TensorFlow框架增加psNode和workerNode参数

![image](https://github.com/user-attachments/assets/3dc831ef-cde1-4b46-be44-ab19710a3281)
### 2.修改了分布式训练框架的位置,这样从单节点增加到多节点时,不用再去表单上面选择框架

![image](https://github.com/user-attachments/assets/71efca5c-3164-449b-b9c9-6aba8e81b745)
### 3.因增加psNode和workerNode参数,相应地修改了再次提交作业
  • Loading branch information
OYX-1 authored Aug 15, 2024
1 parent 83df60b commit 3d36aa0
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 16 deletions.
5 changes: 5 additions & 0 deletions .changeset/tidy-pianos-reply.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@scow/ai": patch
---

TensorFlow 增加 psNode 和 workerNode 参数
96 changes: 81 additions & 15 deletions apps/ai/src/app/(auth)/jobs/[clusterId]/LaunchAppForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ interface FixedFormFields {
account: string;
maxTime: number;
command?: string;
// TensorFlow特有参数
psNodes?: number;
workerNodes?: number;
}

interface CustomFormFields {
Expand Down Expand Up @@ -243,6 +246,8 @@ export const LaunchAppForm = (props: Props) => {

const gpuCount = Form.useWatch("gpuCount", form)!;

const framework = Form.useWatch("framework", form);

const memorySize = (currentPartitionInfo ?
currentPartitionInfo.gpus ? nodeCount * gpuCount
* Math.floor(currentPartitionInfo.cores / currentPartitionInfo.gpus)
Expand Down Expand Up @@ -478,6 +483,8 @@ export const LaunchAppForm = (props: Props) => {
const customAttributes = "customAttributes" in inputParams ? inputParams.customAttributes : {};
const command = "command" in inputParams ? inputParams.command : undefined;
const framework = "framework" in inputParams ? inputParams.framework : undefined;
const psNodes = "psNodes" in inputParams ? inputParams.psNodes : undefined;
const workerNodes = "workerNodes" in inputParams ? inputParams.workerNodes : undefined;
form.setFieldsValue({
mountPoints,
customFields: {
Expand All @@ -493,6 +500,8 @@ export const LaunchAppForm = (props: Props) => {
maxTime,
appJobName: genAppJobName(appName ?? "trainJobs"),
command,
psNodes,
workerNodes,
});
}

Expand All @@ -518,17 +527,26 @@ export const LaunchAppForm = (props: Props) => {
},
});

const handleFormChange = (changedValues: Partial<FormFields>, allValues: FormFields) => {
const { psNodes, workerNodes } = allValues;
if ("psNodes" in changedValues || "workerNodes" in changedValues) {
const newTotal = (psNodes || 0) + (workerNodes || 0);
form.setFieldsValue({ nodeCount: newTotal });
}
};

return (
<Form
form={form}
initialValues={{
... initialValues,
}}
onValuesChange={handleFormChange}
onFinish={async () => {

const { appJobName, algorithm, dataset, image, remoteImageUrl, framework, startCommand, model,
mountPoints, account, partition, coreCount,
gpuCount, maxTime, command, customFields } = await form.validateFields();
gpuCount, maxTime, command, customFields, psNodes, workerNodes } = await form.validateFields();

if (isTraining) {
await trainJobMutation.mutateAsync({
Expand All @@ -555,6 +573,8 @@ export const LaunchAppForm = (props: Props) => {
memory: memorySize,
command: command || "",
gpuType: currentPartitionInfo!.gpuType,
psNodes,
workerNodes,
});
} else {
let workingDirectory: string | undefined;
Expand Down Expand Up @@ -698,20 +718,6 @@ export const LaunchAppForm = (props: Props) => {
>
<Input placeholder="请输入远程镜像地址" />
</Form.Item>
{/* 分布式训练或者华为的卡训练,需要指定训练框架 */}
{(isTraining && (nodeCount > 1 || currentPartitionInfo?.gpuType === "huawei.com/Ascend910")) ? (
<>
{/* 手动选择算法框架,下拉框只有 tensorflow, pytorch */}
<Form.Item
label="分布式训练框架"
name="framework"
rules={[{ required: true }]}
>
<Select options={frameworkOptions}>
</Select>
</Form.Item>
</>
) : null}
{(customImage && !isTraining) ? (
<Form.Item
label="启动命令"
Expand Down Expand Up @@ -1023,8 +1029,54 @@ export const LaunchAppForm = (props: Props) => {
min={1}
max={isTraining ? undefined : 1}
{...inputNumberFloorConfig}
// framework是tensorflow且不是华为卡时 不允许手动改
disabled={isTraining && framework === "tensorflow"
&& (currentPartitionInfo ? currentPartitionInfo.gpuType !== "huawei.com/Ascend910" : true)}
/>
</Form.Item>
{/* tensorflow训练框架时,除了huawei.com/Ascend910的卡之外,都要区分PS node 和worker node */}
{(isTraining && framework === "tensorflow"
&& (currentPartitionInfo ? currentPartitionInfo.gpuType !== "huawei.com/Ascend910" : true)) ? (
<>
<Form.Item
label="PS节点数"
name="psNodes"
initialValue={1}
rules={[
{ required: true,
type: "integer",
},
]}
>
<InputNumber
defaultValue={1}
min={1}
{...inputNumberFloorConfig}
/>
</Form.Item>
</>
) : null}
{(isTraining && framework === "tensorflow"
&& (currentPartitionInfo ? currentPartitionInfo.gpuType !== "huawei.com/Ascend910" : true)) ? (
<>
<Form.Item
label="worker节点数"
name="workerNodes"
initialValue={nodeCount - 1}
rules={[
{ required: true,
type: "integer",
},
]}
>
<InputNumber
defaultValue={nodeCount - 1}
min={1}
{...inputNumberFloorConfig}
/>
</Form.Item>
</>
) : null}
{
currentPartitionInfo?.gpus ? (
<Form.Item
Expand Down Expand Up @@ -1080,6 +1132,20 @@ export const LaunchAppForm = (props: Props) => {
</Form.Item>
)
}
{/* 分布式训练或者华为的卡训练,需要指定训练框架 */}
{(isTraining && (nodeCount > 1 || currentPartitionInfo?.gpuType === "huawei.com/Ascend910")) ? (
<>
{/* 手动选择算法框架,下拉框只有 tensorflow, pytorch */}
<Form.Item
label="分布式训练框架"
name="framework"
rules={[{ required: true }]}
>
<Select options={frameworkOptions}>
</Select>
</Form.Item>
</>
) : null}
<Form.Item label="最长运行时间" name="maxTime" rules={[{ required: true }]}>
<InputNumber min={1} step={1} addonAfter="分钟" />
</Form.Item>
Expand Down
7 changes: 6 additions & 1 deletion apps/ai/src/server/trpc/route/jobs/jobs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ const TrainJobInputSchema = z.object({
maxTime: z.number(),
command: z.string(),
gpuType: z.string().optional(),
// TensorFlow特有参数
psNodes: z.number().optional(),
workerNodes: z.number().optional(),
});

export type TrainJobInput = z.infer<typeof TrainJobInputSchema>;
Expand Down Expand Up @@ -134,7 +137,7 @@ procedure
async ({ input, ctx: { user } }) => {
const { clusterId, trainJobName, isAlgorithmPrivate, algorithm, image, framework, remoteImageUrl,
isDatasetPrivate, dataset, isModelPrivate, model, mountPoints = [], account, partition,
coreCount, nodeCount, gpuCount, memory, maxTime, command, gpuType } = input;
coreCount, nodeCount, gpuCount, memory, maxTime, command, gpuType, psNodes, workerNodes } = input;
const userId = user.identityId;

const host = getClusterLoginNode(clusterId);
Expand Down Expand Up @@ -253,6 +256,8 @@ procedure
// 如果nodeCount不为1但同时选定镜像又没有框架标签,该接口会报错
(nodeCount === 1 && !gpuType?.startsWith("huawei.com")) ? "" : framework || "",
],
psNodeCount:psNodes,
workerNodeCount:workerNodes,
}).catch((e) => {
const ex = e as ServiceError;
throw new TRPCError({
Expand Down

0 comments on commit 3d36aa0

Please sign in to comment.