diff --git a/.gitignore b/.gitignore index c6b2ccddd..d8d9c611b 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,4 @@ tusd_*_* __pycache__/ examples/hooks/plugin/hook_handler .idea/ -.vscode/ \ No newline at end of file +.vscode/ diff --git a/cmd/tusd/cli/flags.go b/cmd/tusd/cli/flags.go index f77638d4c..2ec2c6c86 100644 --- a/cmd/tusd/cli/flags.go +++ b/cmd/tusd/cli/flags.go @@ -55,6 +55,10 @@ var Flags struct { GrpcHooksEndpoint string GrpcHooksRetry int GrpcHooksBackoff time.Duration + GrpcHooksSecure bool + GrpcHooksServerTLSCertFile string + GrpcHooksClientTLSCertFile string + GrpcHooksClientTLSKeyFile string EnabledHooks []hooks.HookType ProgressHooksInterval time.Duration ShowVersion bool @@ -165,6 +169,10 @@ func ParseFlags() { f.StringVar(&Flags.GrpcHooksEndpoint, "hooks-grpc", "", "An gRPC endpoint to which hook events will be sent to") f.IntVar(&Flags.GrpcHooksRetry, "hooks-grpc-retry", 3, "Number of times to retry on a server error or network timeout") f.DurationVar(&Flags.GrpcHooksBackoff, "hooks-grpc-backoff", 1*time.Second, "Wait period before retrying each retry") + f.BoolVar(&Flags.GrpcHooksSecure, "hooks-grpc-secure", false, "Enables secure connection via TLS certificates to the specified gRPC endpoint") + f.StringVar(&Flags.GrpcHooksServerTLSCertFile, "hooks-grpc-server-tls-certificate", "", "Path to the file containing the TLS certificate of the remote gRPC server") + f.StringVar(&Flags.GrpcHooksClientTLSCertFile, "hooks-grpc-client-tls-certificate", "", "Path to the file containing the client certificate for mTLS") + f.StringVar(&Flags.GrpcHooksClientTLSKeyFile, "hooks-grpc-client-tls-key", "", "Path to the file containing the client key for mTLS") }) fs.AddGroup("Plugin hook options", func(f *flag.FlagSet) { diff --git a/cmd/tusd/cli/hooks.go b/cmd/tusd/cli/hooks.go index 3df8024af..7ec97fb26 100644 --- a/cmd/tusd/cli/hooks.go +++ b/cmd/tusd/cli/hooks.go @@ -31,9 +31,13 @@ func getHookHandler(config *handler.Config) hooks.HookHandler { stdout.Printf("Using '%s' as the endpoint for gRPC hooks", Flags.GrpcHooksEndpoint) return &grpc.GrpcHook{ - Endpoint: Flags.GrpcHooksEndpoint, - MaxRetries: Flags.GrpcHooksRetry, - Backoff: Flags.GrpcHooksBackoff, + Endpoint: Flags.GrpcHooksEndpoint, + MaxRetries: Flags.GrpcHooksRetry, + Backoff: Flags.GrpcHooksBackoff, + Secure: Flags.GrpcHooksSecure, + ServerTLSCertificateFilePath: Flags.GrpcHooksServerTLSCertFile, + ClientTLSCertificateFilePath: Flags.GrpcHooksClientTLSCertFile, + ClientTLSCertificateKeyFilePath: Flags.GrpcHooksClientTLSKeyFile, } } else if Flags.PluginHookPath != "" { stdout.Printf("Using '%s' to load plugin for hooks", Flags.PluginHookPath) diff --git a/pkg/hooks/grpc/grpc.go b/pkg/hooks/grpc/grpc.go index 92b0e1c45..3bc4a4c88 100644 --- a/pkg/hooks/grpc/grpc.go +++ b/pkg/hooks/grpc/grpc.go @@ -5,32 +5,78 @@ package grpc import ( "context" + "crypto/tls" + "crypto/x509" + "errors" "net/http" + "os" "time" grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" "github.com/tus/tusd/v2/pkg/hooks" pb "github.com/tus/tusd/v2/pkg/hooks/grpc/proto" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" ) type GrpcHook struct { - Endpoint string - MaxRetries int - Backoff time.Duration - Client pb.HookHandlerClient + Endpoint string + MaxRetries int + Backoff time.Duration + Client pb.HookHandlerClient + Secure bool + ServerTLSCertificateFilePath string + ClientTLSCertificateFilePath string + ClientTLSCertificateKeyFilePath string } func (g *GrpcHook) Setup() error { + grpcOpts := []grpc.DialOption{} + + if g.Secure { + if g.ServerTLSCertificateFilePath == "" { + return errors.New("hooks-grpc-secure was set to true but no gRPC server TLS certificate file was provided. A value for hooks-grpc-server-tls-certificate is missing") + } + + // Load the server's TLS certificate if provided + serverCert, err := os.ReadFile(g.ServerTLSCertificateFilePath) + if err != nil { + return err + } + + // Create a certificate pool and add the server's certificate + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(serverCert) + + // Create TLS configuration with the server's CA certificate + tlsConfig := &tls.Config{ + RootCAs: certPool, + } + + // If client's TLS certificate and key file paths are provided, use mutual TLS + if g.ClientTLSCertificateFilePath != "" && g.ClientTLSCertificateKeyFilePath != "" { + // Load the client's TLS certificate and private key + clientCert, err := tls.LoadX509KeyPair(g.ClientTLSCertificateFilePath, g.ClientTLSCertificateKeyFilePath) + if err != nil { + return err + } + + // Append client certificate to the TLS configuration + tlsConfig.Certificates = append(tlsConfig.Certificates, clientCert) + } + + grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) + } else { + grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + opts := []grpc_retry.CallOption{ grpc_retry.WithBackoff(grpc_retry.BackoffLinear(g.Backoff)), grpc_retry.WithMax(uint(g.MaxRetries)), } - grpcOpts := []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)), - } + grpcOpts = append(grpcOpts, grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...))) + conn, err := grpc.Dial(g.Endpoint, grpcOpts...) if err != nil { return err