diff --git a/docs/reference/deploy/master-config-reference.rst b/docs/reference/deploy/master-config-reference.rst index dc474b3df4d..661b08461cf 100644 --- a/docs/reference/deploy/master-config-reference.rst +++ b/docs/reference/deploy/master-config-reference.rst @@ -1524,6 +1524,11 @@ Specifies configuration settings for SSH. Number of bits to use when generating RSA keys for SSH for tasks. Maximum size is 16384. +``key_type`` +============ + +Specifies the crypto system for SSH. Currently accepts ``RSA``, ``ECDSA`` or ``ED25519``. + ``authz`` ========= diff --git a/docs/release-notes/ssh-crypto-system.rst b/docs/release-notes/ssh-crypto-system.rst new file mode 100644 index 00000000000..acd54812832 --- /dev/null +++ b/docs/release-notes/ssh-crypto-system.rst @@ -0,0 +1,8 @@ +:orphan: + +**Improvements** + +- Master Configuration: Add support for crypto system configuration for ssh connection. + ``security.key_type`` now accepts ``RSA``, ``ECDSA`` or ``ED25519``. Default key type is changed + from ``1024-bit RSA`` to ``ED25519``, since ``ED25519`` keys are faster and more secure than the + old default, and ``ED25519`` is also the default key type for ``ssh-keygen``. diff --git a/harness/determined/cli/shell.py b/harness/determined/cli/shell.py index e9af72da874..cc3e211dff0 100644 --- a/harness/determined/cli/shell.py +++ b/harness/determined/cli/shell.py @@ -1,7 +1,6 @@ import argparse import contextlib import functools -import getpass import os import pathlib import platform @@ -22,9 +21,6 @@ def start_shell(args: argparse.Namespace) -> None: sess = cli.setup_session(args) - data = {} - if args.passphrase: - data["passphrase"] = getpass.getpass("Enter new passphrase: ") config = ntsc.parse_config(args.config_file, None, args.config, args.volume) workspace_id = cli.workspace.get_workspace_id_from_args(args) @@ -35,7 +31,6 @@ def start_shell(args: argparse.Namespace) -> None: args.template, context_path=args.context, includes=args.include, - data=data, workspace_id=workspace_id, ) shell = bindings.v1LaunchShellResponse.from_json(resp).shell @@ -280,12 +275,6 @@ def _open_shell( help=ntsc.INCLUDE_DESC, ), cli.Arg("--config", action="append", default=[], help=ntsc.CONFIG_DESC), - cli.Arg( - "-p", - "--passphrase", - action="store_true", - help="passphrase to encrypt the shell private key", - ), cli.Arg( "--template", type=str, diff --git a/master/internal/api_shell.go b/master/internal/api_shell.go index dbadba57147..4e744f1974c 100644 --- a/master/internal/api_shell.go +++ b/master/internal/api_shell.go @@ -3,7 +3,6 @@ package internal import ( "archive/tar" "context" - "encoding/json" "fmt" "strconv" @@ -253,20 +252,7 @@ func (a *apiServer) LaunchShell( } maps.Copy(launchReq.Spec.Base.ExtraEnvVars, oidcPachydermEnvVars) - var passphrase *string - if len(req.Data) > 0 { - var data map[string]interface{} - if err = json.Unmarshal(req.Data, &data); err != nil { - return nil, status.Errorf(codes.Internal, "failed to parse data %s: %s", req.Data, err) - } - if pwd, ok := data["passphrase"]; ok { - if typed, typedOK := pwd.(string); typedOK { - passphrase = &typed - } - } - } - - keys, err := ssh.GenerateKey(launchReq.Spec.Base.SSHRsaSize, passphrase) + keys, err := ssh.GenerateKey(launchReq.Spec.Base.SSHConfig) if err != nil { return nil, status.Error(codes.Internal, err.Error()) } diff --git a/master/internal/api_user_intg_test.go b/master/internal/api_user_intg_test.go index c0f1d673abf..3e55576a552 100644 --- a/master/internal/api_user_intg_test.go +++ b/master/internal/api_user_intg_test.go @@ -118,7 +118,7 @@ func setupAPITest(t *testing.T, pgdb *db.PgDB, TaskContainerDefaults: model.TaskContainerDefaultsConfig{}, ResourceConfig: *config.DefaultResourceConfig(), }, - taskSpec: &tasks.TaskSpec{SSHRsaSize: 1024}, + taskSpec: &tasks.TaskSpec{SSHConfig: config.SSHConfig{KeyType: "ED25519"}}, allRms: map[string]rm.ResourceManager{config.DefaultClusterName: mockRM}, }, } diff --git a/master/internal/config/config.go b/master/internal/config/config.go index 306cc1b27bf..f68f1cc6095 100644 --- a/master/internal/config/config.go +++ b/master/internal/config/config.go @@ -42,6 +42,15 @@ const ( preemptionScheduler = "preemption" ) +const ( + // KeyTypeRSA uses RSA. + KeyTypeRSA = "RSA" + // KeyTypeECDSA uses ECDSA. + KeyTypeECDSA = "ECDSA" + // KeyTypeED25519 uses ED25519. + KeyTypeED25519 = "ED25519" +) + type ( // ExperimentConfigPatch is the updatedble fields for patching an experiment. ExperimentConfigPatch struct { @@ -108,7 +117,7 @@ func DefaultConfig() *Config { Group: "root", }, SSH: SSHConfig{ - RsaKeySize: 1024, + KeyType: KeyTypeED25519, }, AuthZ: *DefaultAuthZConfig(), }, @@ -452,7 +461,8 @@ type SecurityConfig struct { // SSHConfig is the configuration setting for SSH. type SSHConfig struct { - RsaKeySize int `json:"rsa_key_size"` + RsaKeySize int `json:"rsa_key_size"` + KeyType string `json:"key_type"` } // TLSConfig is the configuration for setting up serving over TLS. @@ -475,10 +485,15 @@ func (t *TLSConfig) Validate() []error { // Validate implements the check.Validatable interface. func (t *SSHConfig) Validate() []error { var errs []error - if t.RsaKeySize < 1 { - errs = append(errs, errors.New("RSA Key size must be greater than 0")) - } else if t.RsaKeySize > 16384 { - errs = append(errs, errors.New("RSA Key size must be less than 16,384")) + if t.KeyType != KeyTypeRSA && t.KeyType != KeyTypeECDSA && t.KeyType != KeyTypeED25519 { + errs = append(errs, errors.New("Crypto system must be one of 'RSA', 'ECDSA' or 'ED25519'")) + } + if t.KeyType == KeyTypeRSA { + if t.RsaKeySize < 1 { + errs = append(errs, errors.New("RSA Key size must be greater than 0")) + } else if t.RsaKeySize > 16384 { + errs = append(errs, errors.New("RSA Key size must be less than 16,384")) + } } return errs } diff --git a/master/internal/core.go b/master/internal/core.go index 3b0bd270197..6640b8332d6 100644 --- a/master/internal/core.go +++ b/master/internal/core.go @@ -1242,7 +1242,7 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error { HarnessPath: filepath.Join(m.config.Root, "wheels"), TaskContainerDefaults: m.config.TaskContainerDefaults, MasterCert: config.GetCertPEM(cert), - SSHRsaSize: m.config.Security.SSH.RsaKeySize, + SSHConfig: m.config.Security.SSH, SegmentEnabled: m.config.Telemetry.Enabled && m.config.Telemetry.SegmentMasterKey != "", SegmentAPIKey: m.config.Telemetry.SegmentMasterKey, LogRetentionDays: m.config.RetentionPolicy.LogRetentionDays, diff --git a/master/internal/core_intg_test.go b/master/internal/core_intg_test.go index d8386407c0f..4935cc38a05 100644 --- a/master/internal/core_intg_test.go +++ b/master/internal/core_intg_test.go @@ -121,7 +121,7 @@ func TestRun(t *testing.T) { DefaultLoggingConfig: &model.DefaultLoggingConfig{}, }, }, - taskSpec: &tasks.TaskSpec{SSHRsaSize: 1024}, + taskSpec: &tasks.TaskSpec{SSHConfig: config.SSHConfig{KeyType: "ED25519"}}, } require.NoError(t, m.config.Resolve()) m.config.DB = config.DBConfig{ diff --git a/master/internal/experiment.go b/master/internal/experiment.go index c980ccdb1e5..8d2d51c79d0 100644 --- a/master/internal/experiment.go +++ b/master/internal/experiment.go @@ -158,7 +158,7 @@ func newExperiment( taskSpec.AgentUserGroup = agentUserGroup - generatedKeys, err := ssh.GenerateKey(taskSpec.SSHRsaSize, nil) + generatedKeys, err := ssh.GenerateKey(taskSpec.SSHConfig) if err != nil { return nil, nil, errors.Wrap(err, "generating ssh keys for trials") } diff --git a/master/internal/trial_intg_test.go b/master/internal/trial_intg_test.go index 59b5b079749..5681a1065ab 100644 --- a/master/internal/trial_intg_test.go +++ b/master/internal/trial_intg_test.go @@ -172,7 +172,6 @@ func setup(t *testing.T) ( &model.Checkpoint{}, &tasks.TaskSpec{ AgentUserGroup: &model.AgentUserGroup{}, - SSHRsaSize: 1024, Workspace: model.DefaultWorkspaceName, }, ssh.PrivateAndPublicKeys{}, diff --git a/master/pkg/ssh/ssh.go b/master/pkg/ssh/ssh.go index 026a87523db..73a506b09e6 100644 --- a/master/pkg/ssh/ssh.go +++ b/master/pkg/ssh/ssh.go @@ -1,6 +1,9 @@ package ssh import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -8,10 +11,13 @@ import ( "github.com/pkg/errors" sshlib "golang.org/x/crypto/ssh" + + "github.com/determined-ai/determined/master/internal/config" ) const ( - trialPEMBlockType = "RSA PRIVATE KEY" + rsaPEMBlockType = "RSA PRIVATE KEY" + ecdsaPEMBlockType = "EC PRIVATE KEY" ) // PrivateAndPublicKeys contains a private and public key. @@ -21,40 +27,93 @@ type PrivateAndPublicKeys struct { } // GenerateKey returns a private and public SSH key. -func GenerateKey(rsaKeySize int, passphrase *string) (PrivateAndPublicKeys, error) { +func GenerateKey(conf config.SSHConfig) (PrivateAndPublicKeys, error) { var generatedKeys PrivateAndPublicKeys + switch conf.KeyType { + case config.KeyTypeRSA: + return generateRSAKey(conf.RsaKeySize) + case config.KeyTypeECDSA: + return generateECDSAKey() + case config.KeyTypeED25519: + return generateED25519Key() + default: + return generatedKeys, errors.New("Invalid crypto system") + } +} + +func generateRSAKey(rsaKeySize int) (PrivateAndPublicKeys, error) { privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySize) if err != nil { - return generatedKeys, errors.Wrap(err, "unable to generate private key") + return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate RSA private key") } if err = privateKey.Validate(); err != nil { - return generatedKeys, err + return PrivateAndPublicKeys{}, err } block := &pem.Block{ - Type: trialPEMBlockType, + Type: rsaPEMBlockType, Bytes: x509.MarshalPKCS1PrivateKey(privateKey), } - if passphrase != nil { - // TODO: Replace usage of deprecated x509.EncryptPEMBlock. - block, err = x509.EncryptPEMBlock( //nolint: staticcheck - rand.Reader, block.Type, block.Bytes, []byte(*passphrase), x509.PEMCipherAES256) - if err != nil { - return generatedKeys, errors.Wrap(err, "unable to encrypt private key") - } + publicKey, err := sshlib.NewPublicKey(&privateKey.PublicKey) + if err != nil { + return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate RSA public key") + } + + return PrivateAndPublicKeys{ + PrivateKey: pem.EncodeToMemory(block), + PublicKey: sshlib.MarshalAuthorizedKey(publicKey), + }, nil +} + +func generateECDSAKey() (PrivateAndPublicKeys, error) { + // Curve size currently not configurable, using the NIST recommendation. + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ECDSA private key") + } + + privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to marshal ECDSA private key") + } + + block := &pem.Block{ + Type: ecdsaPEMBlockType, + Bytes: privateKeyBytes, } publicKey, err := sshlib.NewPublicKey(&privateKey.PublicKey) if err != nil { - return generatedKeys, errors.Wrap(err, "unable to generate public key") + return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ECDSA public key") } - generatedKeys = PrivateAndPublicKeys{ + return PrivateAndPublicKeys{ PrivateKey: pem.EncodeToMemory(block), PublicKey: sshlib.MarshalAuthorizedKey(publicKey), + }, nil +} + +func generateED25519Key() (PrivateAndPublicKeys, error) { + ed25519PublicKey, privateKey, err := ed25519.GenerateKey(nil) + if err != nil { + return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ED25519 private key") } - return generatedKeys, nil + // Before OpenSSH 9.6, for ED25519 keys, only the OpenSSH private key format was supported. + block, err := sshlib.MarshalPrivateKey(privateKey, "") + if err != nil { + return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to marshal ED25519 private key") + } + + publicKey, err := sshlib.NewPublicKey(ed25519PublicKey) + if err != nil { + return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ED25519 public key") + } + + return PrivateAndPublicKeys{ + PrivateKey: pem.EncodeToMemory(block), + PublicKey: sshlib.MarshalAuthorizedKey(publicKey), + }, nil } diff --git a/master/pkg/ssh/ssh_test.go b/master/pkg/ssh/ssh_test.go new file mode 100644 index 00000000000..a5397ac1321 --- /dev/null +++ b/master/pkg/ssh/ssh_test.go @@ -0,0 +1,39 @@ +package ssh + +import ( + "testing" + + "golang.org/x/crypto/ssh" + "gotest.tools/assert" + + "github.com/determined-ai/determined/master/internal/config" +) + +func verifyKeys(t *testing.T, keys PrivateAndPublicKeys) { + privateKey, err := ssh.ParsePrivateKey(keys.PrivateKey) + assert.NilError(t, err) + + publickKey, _, _, _, err := ssh.ParseAuthorizedKey(keys.PublicKey) //nolint:dogsled + assert.NilError(t, err) + assert.Equal(t, string(publickKey.Marshal()), string(privateKey.PublicKey().Marshal())) +} + +func TestSSHKeyGenerate(t *testing.T) { + t.Run("generate RSA key", func(t *testing.T) { + keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeRSA, RsaKeySize: 512}) + assert.NilError(t, err) + verifyKeys(t, keys) + }) + + t.Run("generate ECDSA key", func(t *testing.T) { + keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeECDSA}) + assert.NilError(t, err) + verifyKeys(t, keys) + }) + + t.Run("generate ED25519 key", func(t *testing.T) { + keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeED25519}) + assert.NilError(t, err) + verifyKeys(t, keys) + }) +} diff --git a/master/pkg/tasks/task.go b/master/pkg/tasks/task.go index 12b000c2de3..bc563197d66 100644 --- a/master/pkg/tasks/task.go +++ b/master/pkg/tasks/task.go @@ -12,6 +12,7 @@ import ( "github.com/docker/docker/api/types/mount" "github.com/jinzhu/copier" + "github.com/determined-ai/determined/master/internal/config" "github.com/determined-ai/determined/master/pkg/archive" "github.com/determined-ai/determined/master/pkg/cproto" "github.com/determined-ai/determined/master/pkg/device" @@ -70,7 +71,7 @@ type TaskSpec struct { ClusterID string HarnessPath string MasterCert []byte - SSHRsaSize int + SSHConfig config.SSHConfig SegmentEnabled bool SegmentAPIKey string diff --git a/proto/pkg/apiv1/shell.pb.go b/proto/pkg/apiv1/shell.pb.go index 8138e6f888b..e5d9c40b5d7 100644 --- a/proto/pkg/apiv1/shell.pb.go +++ b/proto/pkg/apiv1/shell.pb.go @@ -570,7 +570,7 @@ type LaunchShellRequest struct { TemplateName string `protobuf:"bytes,2,opt,name=template_name,json=templateName,proto3" json:"template_name,omitempty"` // The files to run with the command. Files []*utilv1.File `protobuf:"bytes,3,rep,name=files,proto3" json:"files,omitempty"` - // Additional data. + // Deprecated: Do not use. Data []byte `protobuf:"bytes,4,opt,name=data,proto3" json:"data,omitempty"` // Workspace ID. Defaults to 'Uncategorized' workspace if not specified. WorkspaceId int32 `protobuf:"varint,5,opt,name=workspace_id,json=workspaceId,proto3" json:"workspace_id,omitempty"` diff --git a/proto/src/determined/api/v1/shell.proto b/proto/src/determined/api/v1/shell.proto index b3a2c1fe7af..76504a078bf 100644 --- a/proto/src/determined/api/v1/shell.proto +++ b/proto/src/determined/api/v1/shell.proto @@ -102,7 +102,7 @@ message LaunchShellRequest { string template_name = 2; // The files to run with the command. repeated determined.util.v1.File files = 3; - // Additional data. + // Deprecated: Do not use. bytes data = 4; // Workspace ID. Defaults to 'Uncategorized' workspace if not specified. int32 workspace_id = 5; diff --git a/webui/react/src/services/api-ts-sdk/api.ts b/webui/react/src/services/api-ts-sdk/api.ts index 2329eb9787d..a92ff0b34b5 100644 --- a/webui/react/src/services/api-ts-sdk/api.ts +++ b/webui/react/src/services/api-ts-sdk/api.ts @@ -6154,7 +6154,7 @@ export interface V1LaunchShellRequest { */ files?: Array; /** - * Additional data. + * Deprecated: Do not use. * @type {string} * @memberof V1LaunchShellRequest */