From 7227c56d3345c3de5dd00d699ea5d9355fa6b229 Mon Sep 17 00:00:00 2001 From: Wei Zang Date: Mon, 2 Sep 2024 19:47:00 +0800 Subject: [PATCH] feat: add cert pool --- pool.go | 144 +++++++++++++++++++++++++++++++++++++++++++++++++++ pool_test.go | 28 ++++++++++ 2 files changed, 172 insertions(+) create mode 100644 pool.go create mode 100644 pool_test.go diff --git a/pool.go b/pool.go new file mode 100644 index 0000000..a5433a9 --- /dev/null +++ b/pool.go @@ -0,0 +1,144 @@ +package appstore + +import ( + "crypto/x509" + "embed" + "io" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "regexp" + "strings" + "sync" +) + +//go:embed certs/*.cer +var certs embed.FS + +const srcUrl = "https://www.apple.com/certificateauthority/" +const outDir = "certs/" + +var certLinkPattern = regexp.MustCompile(`]*href="([^"]+\.cer)"`) + +type CertPool struct { + pool *x509.CertPool + poolOnce sync.Once +} + +func NewCertPool() (*CertPool, error) { + cp := &CertPool{} + err := cp.Init() + if err != nil { + return nil, err + } + return cp, nil +} + +func (cp *CertPool) Init() error { + var err error + cp.poolOnce.Do(func() { + cp.pool = x509.NewCertPool() + err = cp.downloadCerts() + err = cp.loadCerts() + }) + return err +} + +func (cp *CertPool) downloadCerts() error { + resp, err := http.Get(srcUrl) + if err != nil { + return err + } + defer resp.Body.Close() + + content, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + + if err := os.RemoveAll(outDir); err != nil { + return err + } + if err := os.MkdirAll(outDir, 0755); err != nil { + return err + } + + matches := certLinkPattern.FindAllSubmatch(content, -1) + for _, match := range matches { + certUrl, err := cp.constructCertUrl(string(match[1])) + if err != nil { + return err + } + + if err := cp.downloadAndSaveCert(certUrl); err != nil { + return err + } + } + return nil +} + +func (cp *CertPool) constructCertUrl(certPath string) (string, error) { + if certPath[0] == '/' { + baseUrl, err := url.Parse(srcUrl) + if err != nil { + return "", err + } + baseUrl.Path = certPath + return baseUrl.String(), nil + } else if strings.HasPrefix(certPath, "https://www.apple.com/") || strings.HasPrefix(certPath, "https://developer.apple.com/") { + return certPath, nil + } else { + return url.JoinPath(srcUrl, certPath) + } +} + +func (cp *CertPool) downloadAndSaveCert(certUrl string) error { + resp, err := http.Get(certUrl) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil + } + + fileName := path.Base(certUrl) + filePath := filepath.Join(outDir, fileName) + f, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return err + } + defer f.Close() + + _, err = io.Copy(f, resp.Body) + return err +} + +func (cp *CertPool) loadCerts() error { + entries, err := certs.ReadDir("certs") + if err != nil { + return err + } + for _, entry := range entries { + if !entry.IsDir() && entry.Type().IsRegular() { + cert, err := certs.ReadFile("certs/" + entry.Name()) + if err != nil { + continue + } + if ok := cp.pool.AppendCertsFromPEM(cert); ok { + continue + } + if cer, err := x509.ParseCertificate(cert); err == nil { + cp.pool.AddCert(cer) + } + } + } + return nil +} + +func (cp *CertPool) GetCertPool() *x509.CertPool { + return cp.pool +} diff --git a/pool_test.go b/pool_test.go new file mode 100644 index 0000000..eed0184 --- /dev/null +++ b/pool_test.go @@ -0,0 +1,28 @@ +package appstore + +import ( + "reflect" + "testing" +) + +func TestNewCertPool(t *testing.T) { + tests := []struct { + name string + want *CertPool + wantErr bool + }{ + {"test", &CertPool{}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewCertPool() + if (err != nil) != tt.wantErr { + t.Errorf("NewCertPool() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewCertPool() got = %v, want %v", got, tt.want) + } + }) + } +}