diff --git a/cmd/wp-go-static/commands/root.go b/cmd/wp-go-static/commands/root.go index 9389909..c5a6979 100644 --- a/cmd/wp-go-static/commands/root.go +++ b/cmd/wp-go-static/commands/root.go @@ -3,6 +3,7 @@ package commands import ( "crypto/tls" "fmt" + "log" "net/http" "net/url" "regexp" @@ -70,21 +71,17 @@ func init() { RootCmd.PersistentFlags().String("cache", "", "Cache directory") RootCmd.PersistentFlags().Bool("parallel", false, "Fetch in parallel") RootCmd.PersistentFlags().Bool("images", true, "Download images") + RootCmd.MarkFlagRequired("url") // Bind command-line flags to Viper - viper.BindPFlag("dir", RootCmd.PersistentFlags().Lookup("dir")) - viper.BindPFlag("url", RootCmd.PersistentFlags().Lookup("url")) - viper.BindPFlag("cache", RootCmd.PersistentFlags().Lookup("cache")) - viper.BindPFlag("images", RootCmd.PersistentFlags().Lookup("images")) + err := viper.BindPFlags(RootCmd.PersistentFlags()) + if err != nil { + log.Fatal(err) + } viper.AutomaticEnv() viper.EnvKeyReplacer(strings.NewReplacer("-", "_")) viper.SetEnvPrefix("WGS") - - // Execute root command - if err := RootCmd.Execute(); err != nil { - fmt.Println(err) - } } func rootCmdF(command *cobra.Command, args []string) error { @@ -103,9 +100,9 @@ func rootCmdF(command *cobra.Command, args []string) error { scrape.c.Async = parallel - // Ignore SSL errors + // Use a custom TLS config to verify server certificates scrape.c.WithTransport(&http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + TLSClientConfig: &tls.Config{}, }) parsedURL, err := url.Parse(commandURL) @@ -120,31 +117,44 @@ func rootCmdF(command *cobra.Command, args []string) error { // On every a element which has href attribute call callback scrape.c.OnHTML("a[href]", func(e *colly.HTMLElement) { link := e.Attr("href") - scrape.visitURL(e.Request.AbsoluteURL(link)) + scrape.visitURL(link) }) // On every link element call callback scrape.c.OnHTML("link[href]", func(e *colly.HTMLElement) { link := e.Attr("href") - scrape.visitURL(e.Request.AbsoluteURL(link)) + scrape.visitURL(link) }) // On every script element call callback scrape.c.OnHTML("script[src]", func(e *colly.HTMLElement) { link := e.Attr("src") - scrape.visitURL(e.Request.AbsoluteURL(link)) + scrape.visitURL(link) }) // On every img element call callback - scrape.c.OnHTML("img", func(e *colly.HTMLElement) { - link := e.Attr("src") - scrape.visitURL(e.Request.AbsoluteURL(link)) + src := e.Attr("src") + srcSet := e.Attr("srcset") + scrape.visitURL(src) + + if srcSet != "" { + srcSetList := strings.Split(srcSet, ",") + for _, srcSetURL := range srcSetList { + srcSetURL = strings.TrimSpace(srcSetURL) + innerSrcSet := strings.Split(srcSetURL, " ")[0] + if innerSrcSet == "" { + continue + } + + scrape.visitURL(innerSrcSet) + } + } }) // Before making a request print "Visiting ..." scrape.c.OnRequest(func(r *colly.Request) { - fmt.Println("Visiting", r.URL.String()) + log.Println("Visiting", r.URL.String()) }) // On response @@ -154,7 +164,7 @@ func rootCmdF(command *cobra.Command, args []string) error { err := file.SaveFile(r, dir, fileName) if err != nil { - fmt.Println(err) + log.Println(err) return } }) @@ -172,15 +182,38 @@ func rootCmdF(command *cobra.Command, args []string) error { } func (s *Scrape) visitURL(link string) { - // Download image found on page if it hasn't been visited before + link = s.getAbsoluteURL(link) + + if link == "" { + return + } + + u, err := url.Parse(link) + if err != nil { + log.Printf("Error parsing URL %s: %s", link, err) + return + } + + if u.Scheme == "" || u.Host == "" { + log.Printf("Invalid URL %s", link) + return + } + + u.Fragment = "" + + link = u.String() + + // Download page if it hasn't been visited before if !s.urlCache.Get(link) { s.urlCache.Add(link) - s.c.Visit(link) + err := s.c.Visit(link) + if err != nil { + log.Println(err) + } } } func (s *Scrape) parseBody(body []byte) []byte { - // Find all URLs in the CSS file cssUrls := regexp.MustCompile(`url\((https?://[^\s]+)\)`).FindAllStringSubmatch(string(body), -1) // Download each referenced file if it hasn't been visited before @@ -208,3 +241,21 @@ func (s *Scrape) parseBody(body []byte) []byte { return body } + +func (s *Scrape) getAbsoluteURL(inputURL string) string { + parsedURL, err := url.Parse(inputURL) + if err != nil { + return "" + } + + if parsedURL.IsAbs() { + return inputURL + } + + baseURL, err := url.Parse(s.domain) + if err != nil { + return "" + } + + return baseURL.ResolveReference(parsedURL).String() +}