diff --git a/pkg/common/test/client.go b/pkg/common/test/client.go index caacd52..5133953 100644 --- a/pkg/common/test/client.go +++ b/pkg/common/test/client.go @@ -33,5 +33,5 @@ func (fc *FilesComClient) GetFiles(dirs []string) ([]db.File, error) { } func (fc *FilesComClient) Download(toDownload *db.File, downloadPath string) (*files_sdk.File, error) { - return nil, nil + return &files_sdk.File{}, nil } diff --git a/pkg/config/config.go b/pkg/config/config.go index 4aa316d..c9fc44b 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -34,6 +34,7 @@ type Config struct { PollEvery string `yaml:"poll-every" default:"5"` FilesDelta string `yaml:"files-delta" default:"10m"` Filetypes []string `yaml:"filetypes"` + BaseTmpDir string `yaml:"base-tmpdir" default:""` Directories []string `yaml:"directories"` ProcessorMap []struct { Type string `yaml:"type"` diff --git a/pkg/monitor/monitor.go b/pkg/monitor/monitor.go index 31d5d89..969468e 100644 --- a/pkg/monitor/monitor.go +++ b/pkg/monitor/monitor.go @@ -3,6 +3,11 @@ package monitor import ( "context" "fmt" + "os" + "regexp" + "sync" + "time" + "github.com/canonical/athena-core/pkg/common" "github.com/canonical/athena-core/pkg/common/db" "github.com/canonical/athena-core/pkg/config" @@ -11,9 +16,6 @@ import ( "github.com/lileio/pubsub/v2" "github.com/lileio/pubsub/v2/middleware/defaults" log "github.com/sirupsen/logrus" - "regexp" - "sync" - "time" ) type Monitor struct { @@ -148,6 +150,24 @@ func (m *Monitor) PollNewFiles(ctx *context.Context, duration time.Duration) { log.Infof("File %s already dispatched, skipping", file.Path) continue } + log.Infof("Downloading file %s to shared folder", file.Path) + basePath := m.Config.Monitor.BaseTmpDir + if basePath == "" { + basePath = "/tmp" + } + log.Debugf("Using temporary base path: %s", basePath) + fileEntry, err := m.FilesClient.Download(&file, basePath) + if err != nil { + log.Errorf("Failed to download %s: %s", file.Path, err) + } + log.Infof("Downloaded file to %s", fileEntry.Path) + if _, err := os.Stat(basePath); os.IsNotExist(err) { + log.Debugf("Temporary base path '%s' doesn't exist - creating", basePath) + if err = os.MkdirAll(basePath, 0755); err != nil { + log.Errorf("Cannot create temporary base path: %s", err.Error()) + } + } + log.Infof("Sending file: %s to processor: %s", file.Path, processor) publishResults := pubsub.PublishJSON(*ctx, processor, file) if publishResults.Err != nil { diff --git a/pkg/processor/processor.go b/pkg/processor/processor.go index 3c31a19..b825623 100644 --- a/pkg/processor/processor.go +++ b/pkg/processor/processor.go @@ -247,11 +247,6 @@ func NewReportRunner(cfg *config.Config, dbConn *gorm.DB, sf common.SalesforceCl return nil, err } - fileEntry, err := fc.Download(file, dir) - if err != nil { - return nil, err - } - reportRunner.Config = cfg reportRunner.Subscriber = subscriber reportRunner.Name = name @@ -262,9 +257,9 @@ func NewReportRunner(cfg *config.Config, dbConn *gorm.DB, sf common.SalesforceCl //TODO: document the template variables tplContext := pongo2.Context{ - "basedir": reportRunner.Basedir, // base dir used to generate reports - "file": fileEntry, // file entry as returned by the files.com api client - "filepath": path.Join(reportRunner.Basedir, filepath.Base(fileEntry.Path)), // directory where the file lives on + "basedir": reportRunner.Basedir, // base dir used to generate reports + "file": filepath.Base(file.Path), // file entry as returned by the files.com api client + "filepath": path.Join(reportRunner.Basedir, filepath.Base(file.Path)), // directory where the file lives on } var scripts = make(map[string]string) @@ -300,6 +295,12 @@ func NewReportRunner(cfg *config.Config, dbConn *gorm.DB, sf common.SalesforceCl scripts[scriptName] = fd.Name() } + log.Infof("Removing previously downloaded file: %s", filepath.Base(file.Path)) + err = os.Remove(path.Join(basePath, filepath.Base(file.Path))) + if err != nil { + log.Errorf("Could not remove %s: %s", filepath.Base(file.Path), err.Error()) + } + timeout, err := time.ParseDuration(report.Timeout) if err != nil { timeout, _ = time.ParseDuration(DefaultExecutionTimeout)