Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow master configuration for ssh key crypto system #10072

Merged
merged 17 commits into from
Oct 24, 2024
5 changes: 5 additions & 0 deletions docs/reference/deploy/master-config-reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,11 @@ Specifies configuration settings for SSH.

Number of bits to use when generating RSA keys for SSH for tasks. Maximum size is 16384.

``key_type``
============

Specifies the crypto system for SSH. Currently accepts ``RSA``, ``ECDSA`` or ``ED25519``.

``authz``
=========

Expand Down
8 changes: 8 additions & 0 deletions docs/release-notes/ssh-crypto-system.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
:orphan:

**Improvements**

- Master Configuration: Add support for crypto system configuration for ssh connection.
``security.key_type`` now accepts ``RSA``, ``ECDSA`` or ``ED25519``. Default key type is changed
from ``1024-bit RSA`` to ``ED25519``, since ``ED25519`` keys are faster and more secure than the
old default, and ``ED25519`` is also the default key type for ``ssh-keygen``.
11 changes: 0 additions & 11 deletions harness/determined/cli/shell.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import contextlib
import functools
import getpass
import os
import pathlib
import platform
Expand All @@ -22,9 +21,6 @@

def start_shell(args: argparse.Namespace) -> None:
sess = cli.setup_session(args)
data = {}
if args.passphrase:
data["passphrase"] = getpass.getpass("Enter new passphrase: ")
config = ntsc.parse_config(args.config_file, None, args.config, args.volume)
workspace_id = cli.workspace.get_workspace_id_from_args(args)

Expand All @@ -35,7 +31,6 @@ def start_shell(args: argparse.Namespace) -> None:
args.template,
context_path=args.context,
includes=args.include,
data=data,
workspace_id=workspace_id,
)
shell = bindings.v1LaunchShellResponse.from_json(resp).shell
Expand Down Expand Up @@ -280,12 +275,6 @@ def _open_shell(
help=ntsc.INCLUDE_DESC,
),
cli.Arg("--config", action="append", default=[], help=ntsc.CONFIG_DESC),
cli.Arg(
"-p",
"--passphrase",
action="store_true",
help="passphrase to encrypt the shell private key",
),
cli.Arg(
"--template",
type=str,
Expand Down
16 changes: 1 addition & 15 deletions master/internal/api_shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package internal
import (
"archive/tar"
"context"
"encoding/json"
"fmt"
"strconv"

Expand Down Expand Up @@ -253,20 +252,7 @@ func (a *apiServer) LaunchShell(
}
maps.Copy(launchReq.Spec.Base.ExtraEnvVars, oidcPachydermEnvVars)

var passphrase *string
if len(req.Data) > 0 {
var data map[string]interface{}
if err = json.Unmarshal(req.Data, &data); err != nil {
return nil, status.Errorf(codes.Internal, "failed to parse data %s: %s", req.Data, err)
}
if pwd, ok := data["passphrase"]; ok {
if typed, typedOK := pwd.(string); typedOK {
passphrase = &typed
}
}
}

keys, err := ssh.GenerateKey(launchReq.Spec.Base.SSHRsaSize, passphrase)
keys, err := ssh.GenerateKey(launchReq.Spec.Base.SSHConfig)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
Expand Down
2 changes: 1 addition & 1 deletion master/internal/api_user_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func setupAPITest(t *testing.T, pgdb *db.PgDB,
TaskContainerDefaults: model.TaskContainerDefaultsConfig{},
ResourceConfig: *config.DefaultResourceConfig(),
},
taskSpec: &tasks.TaskSpec{SSHRsaSize: 1024},
taskSpec: &tasks.TaskSpec{SSHConfig: config.SSHConfig{KeyType: "ED25519"}},
allRms: map[string]rm.ResourceManager{config.DefaultClusterName: mockRM},
},
}
Expand Down
27 changes: 21 additions & 6 deletions master/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ const (
preemptionScheduler = "preemption"
)

const (
// KeyTypeRSA uses RSA.
KeyTypeRSA = "RSA"
// KeyTypeECDSA uses ECDSA.
KeyTypeECDSA = "ECDSA"
// KeyTypeED25519 uses ED25519.
KeyTypeED25519 = "ED25519"
)

type (
// ExperimentConfigPatch is the updatedble fields for patching an experiment.
ExperimentConfigPatch struct {
Expand Down Expand Up @@ -108,7 +117,7 @@ func DefaultConfig() *Config {
Group: "root",
},
SSH: SSHConfig{
RsaKeySize: 1024,
KeyType: KeyTypeED25519,
},
AuthZ: *DefaultAuthZConfig(),
},
Expand Down Expand Up @@ -452,7 +461,8 @@ type SecurityConfig struct {

// SSHConfig is the configuration setting for SSH.
type SSHConfig struct {
RsaKeySize int `json:"rsa_key_size"`
RsaKeySize int `json:"rsa_key_size"`
KeyType string `json:"key_type"`
}

// TLSConfig is the configuration for setting up serving over TLS.
Expand All @@ -475,10 +485,15 @@ func (t *TLSConfig) Validate() []error {
// Validate implements the check.Validatable interface.
func (t *SSHConfig) Validate() []error {
var errs []error
if t.RsaKeySize < 1 {
errs = append(errs, errors.New("RSA Key size must be greater than 0"))
} else if t.RsaKeySize > 16384 {
errs = append(errs, errors.New("RSA Key size must be less than 16,384"))
if t.KeyType != KeyTypeRSA && t.KeyType != KeyTypeECDSA && t.KeyType != KeyTypeED25519 {
errs = append(errs, errors.New("Crypto system must be one of 'RSA', 'ECDSA' or 'ED25519'"))
}
if t.KeyType == KeyTypeRSA {
if t.RsaKeySize < 1 {
errs = append(errs, errors.New("RSA Key size must be greater than 0"))
} else if t.RsaKeySize > 16384 {
errs = append(errs, errors.New("RSA Key size must be less than 16,384"))
}
}
return errs
}
Expand Down
2 changes: 1 addition & 1 deletion master/internal/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,7 @@ func (m *Master) Run(ctx context.Context, gRPCLogInitDone chan struct{}) error {
HarnessPath: filepath.Join(m.config.Root, "wheels"),
TaskContainerDefaults: m.config.TaskContainerDefaults,
MasterCert: config.GetCertPEM(cert),
SSHRsaSize: m.config.Security.SSH.RsaKeySize,
SSHConfig: m.config.Security.SSH,
SegmentEnabled: m.config.Telemetry.Enabled && m.config.Telemetry.SegmentMasterKey != "",
SegmentAPIKey: m.config.Telemetry.SegmentMasterKey,
LogRetentionDays: m.config.RetentionPolicy.LogRetentionDays,
Expand Down
2 changes: 1 addition & 1 deletion master/internal/core_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestRun(t *testing.T) {
DefaultLoggingConfig: &model.DefaultLoggingConfig{},
},
},
taskSpec: &tasks.TaskSpec{SSHRsaSize: 1024},
taskSpec: &tasks.TaskSpec{SSHConfig: config.SSHConfig{KeyType: "ED25519"}},
}
require.NoError(t, m.config.Resolve())
m.config.DB = config.DBConfig{
Expand Down
2 changes: 1 addition & 1 deletion master/internal/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func newExperiment(

taskSpec.AgentUserGroup = agentUserGroup

generatedKeys, err := ssh.GenerateKey(taskSpec.SSHRsaSize, nil)
generatedKeys, err := ssh.GenerateKey(taskSpec.SSHConfig)
if err != nil {
return nil, nil, errors.Wrap(err, "generating ssh keys for trials")
}
Expand Down
1 change: 0 additions & 1 deletion master/internal/trial_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ func setup(t *testing.T) (
&model.Checkpoint{},
&tasks.TaskSpec{
AgentUserGroup: &model.AgentUserGroup{},
SSHRsaSize: 1024,
Workspace: model.DefaultWorkspaceName,
},
ssh.PrivateAndPublicKeys{},
Expand Down
91 changes: 76 additions & 15 deletions master/pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
package ssh

import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"

"github.com/pkg/errors"
sshlib "golang.org/x/crypto/ssh"

"github.com/determined-ai/determined/master/internal/config"
)

const (
trialPEMBlockType = "RSA PRIVATE KEY"
rsaPEMBlockType = "RSA PRIVATE KEY"
ecdsaPEMBlockType = "EC PRIVATE KEY"
)

// PrivateAndPublicKeys contains a private and public key.
Expand All @@ -21,40 +27,95 @@ type PrivateAndPublicKeys struct {
}

// GenerateKey returns a private and public SSH key.
func GenerateKey(rsaKeySize int, passphrase *string) (PrivateAndPublicKeys, error) {
func GenerateKey(conf config.SSHConfig) (PrivateAndPublicKeys, error) {
var generatedKeys PrivateAndPublicKeys
switch conf.KeyType {
case config.KeyTypeRSA:
return generateRSAKey(conf.RsaKeySize)
case config.KeyTypeECDSA:
return generateECDSAKey()
case config.KeyTypeED25519:
return generateED25519Key()
default:
return generatedKeys, errors.New("Invalid crypto system")
}
}

func generateRSAKey(rsaKeySize int) (PrivateAndPublicKeys, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, rsaKeySize)
if err != nil {
return generatedKeys, errors.Wrap(err, "unable to generate private key")
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate RSA private key")
}

if err = privateKey.Validate(); err != nil {
return generatedKeys, err
return PrivateAndPublicKeys{}, err
}

block := &pem.Block{
Type: trialPEMBlockType,
Type: rsaPEMBlockType,
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
}

if passphrase != nil {
// TODO: Replace usage of deprecated x509.EncryptPEMBlock.
block, err = x509.EncryptPEMBlock( //nolint: staticcheck
rand.Reader, block.Type, block.Bytes, []byte(*passphrase), x509.PEMCipherAES256)
if err != nil {
return generatedKeys, errors.Wrap(err, "unable to encrypt private key")
}
publicKey, err := sshlib.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate RSA public key")
}

return PrivateAndPublicKeys{
PrivateKey: pem.EncodeToMemory(block),
PublicKey: sshlib.MarshalAuthorizedKey(publicKey),
}, nil
}

func generateECDSAKey() (PrivateAndPublicKeys, error) {
// Curve size currently not configurable, using the NIST recommendation.
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just as a note: I don't think there were any requests to make the curve size configurable, and P256 is the NIST recommendation, so this is fine. But one day we may have to update it or make configurable.

if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ECDSA private key")
}

privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to marshal ECDSA private key")
}

block := &pem.Block{
Type: ecdsaPEMBlockType,
Bytes: privateKeyBytes,
}

publicKey, err := sshlib.NewPublicKey(&privateKey.PublicKey)
if err != nil {
return generatedKeys, errors.Wrap(err, "unable to generate public key")
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ECDSA public key")
}

generatedKeys = PrivateAndPublicKeys{
return PrivateAndPublicKeys{
PrivateKey: pem.EncodeToMemory(block),
PublicKey: sshlib.MarshalAuthorizedKey(publicKey),
}, nil
}

func generateED25519Key() (PrivateAndPublicKeys, error) {

ed25519PublicKey, privateKey, err := ed25519.GenerateKey(nil)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ED25519 private key")
}

// Before OpenSSH 9.6, for ED25519 keys, only the OpenSSH private key format was supported.
block, err := sshlib.MarshalPrivateKey(privateKey, "")
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to marshal ED25519 private key")
}

publicKey, err := sshlib.NewPublicKey(ed25519PublicKey)
if err != nil {
return PrivateAndPublicKeys{}, errors.Wrap(err, "unable to generate ED25519 public key")
}

return generatedKeys, nil
return PrivateAndPublicKeys{
PrivateKey: pem.EncodeToMemory(block),
PublicKey: sshlib.MarshalAuthorizedKey(publicKey),
}, nil

}
39 changes: 39 additions & 0 deletions master/pkg/ssh/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package ssh

import (
"testing"

"golang.org/x/crypto/ssh"
"gotest.tools/assert"

"github.com/determined-ai/determined/master/internal/config"
)

func verifyKeys(t *testing.T, keys PrivateAndPublicKeys) {
privateKey, err := ssh.ParsePrivateKey(keys.PrivateKey)
assert.NilError(t, err)

publickKey, _, _, _, err := ssh.ParseAuthorizedKey(keys.PublicKey) //nolint:dogsled
assert.NilError(t, err)
assert.Equal(t, string(publickKey.Marshal()), string(privateKey.PublicKey().Marshal()))
}

func TestSSHKeyGenerate(t *testing.T) {
t.Run("generate RSA key", func(t *testing.T) {
keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeRSA, RsaKeySize: 512})
assert.NilError(t, err)
verifyKeys(t, keys)
})

t.Run("generate ECDSA key", func(t *testing.T) {
keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeECDSA})
assert.NilError(t, err)
verifyKeys(t, keys)
})

t.Run("generate ED25519 key", func(t *testing.T) {
keys, err := GenerateKey(config.SSHConfig{KeyType: config.KeyTypeED25519})
assert.NilError(t, err)
verifyKeys(t, keys)
})
}
3 changes: 2 additions & 1 deletion master/pkg/tasks/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/docker/docker/api/types/mount"
"github.com/jinzhu/copier"

"github.com/determined-ai/determined/master/internal/config"
"github.com/determined-ai/determined/master/pkg/archive"
"github.com/determined-ai/determined/master/pkg/cproto"
"github.com/determined-ai/determined/master/pkg/device"
Expand Down Expand Up @@ -70,7 +71,7 @@ type TaskSpec struct {
ClusterID string
HarnessPath string
MasterCert []byte
SSHRsaSize int
SSHConfig config.SSHConfig

SegmentEnabled bool
SegmentAPIKey string
Expand Down
2 changes: 1 addition & 1 deletion proto/pkg/apiv1/shell.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion proto/src/determined/api/v1/shell.proto
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ message LaunchShellRequest {
string template_name = 2;
// The files to run with the command.
repeated determined.util.v1.File files = 3;
// Additional data.
// Deprecated: Do not use.
bytes data = 4;
// Workspace ID. Defaults to 'Uncategorized' workspace if not specified.
int32 workspace_id = 5;
Expand Down
Loading
Loading