package main import ( "context" "crypto/subtle" "errors" "fmt" "io" "log" "net" "net/http" "os" "path/filepath" "strings" "sync" "syscall" "github.com/mholt/archives" "github.com/otiai10/copy" ) var ( authenticationKey = os.Getenv("AUTH_KEY") subPath = strings.Trim(os.Getenv("SUBPATH"), "/") targetDirectory = os.Getenv("TARGET_DIR") port = os.Getenv("PORT") maxUploadSize int64 = 1 << 30 // 1GB deployLock sync.Mutex ) func main() { if authenticationKey == "" || targetDirectory == "" { log.Fatal("AUTH_KEY and TARGET_DIR must be set") } if port == "" { port = "8080" } basePath := "/" if subPath != "" { basePath = "/" + subPath } log.Printf("starting server on :%s, endpoint %q, target directory %q", port, basePath, targetDirectory) http.HandleFunc(basePath, withRecovery(handle)) log.Fatal(http.ListenAndServe(":"+port, nil)) } func handle(w http.ResponseWriter, r *http.Request) { remoteIP := realIP(r) log.Printf("incoming %q request on %q from %s", r.Method, r.URL.Path, remoteIP) if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } auth := r.Header.Get("Authorization") if subtle.ConstantTimeCompare([]byte(auth), []byte(authenticationKey)) != 1 { log.Printf("unauthorized request from %s", remoteIP) http.Error(w, "unauthorized", http.StatusUnauthorized) return } r.Body = http.MaxBytesReader(w, r.Body, maxUploadSize) uploadDir, err := os.MkdirTemp("", "upload-*") if err != nil { http.Error(w, "internal error", http.StatusInternalServerError) return } defer os.RemoveAll(uploadDir) uploadFilePath := filepath.Join(uploadDir, "file") uploadFile, err := os.Create(uploadFilePath) if err != nil { http.Error(w, "internal error", http.StatusInternalServerError) return } defer uploadFile.Close() if _, err := io.Copy(uploadFile, r.Body); err != nil { http.Error(w, "bad archive", http.StatusBadRequest) return } uploadFile.Seek(0, io.SeekStart) defer uploadFile.Close() ctx := context.Background() format, archiveStream, err := archives.Identify(ctx, uploadFilePath, uploadFile) if err != nil { http.Error(w, "cannot detect archive", http.StatusBadRequest) return } extractor, ok := format.(archives.Extractor) if !ok { http.Error(w, "unsupported archive type", http.StatusBadRequest) return } extractDir, err := os.MkdirTemp("", "extract-*") if err != nil { http.Error(w, "internal error", http.StatusInternalServerError) return } defer os.RemoveAll(extractDir) if err := extractor.Extract(ctx, archiveStream, extract(extractDir)); err != nil { log.Printf("failed to extract archive: %v", err) http.Error(w, "bad archive", http.StatusBadRequest) return } deployLock.Lock() defer deployLock.Unlock() if err := removeContents(targetDirectory); err != nil { http.Error(w, "internal error", http.StatusInternalServerError) return } if err := renameOrCopyContents(extractDir); err != nil { http.Error(w, "internal error", http.StatusInternalServerError) return } w.WriteHeader(http.StatusOK) log.Printf("upload successful from %s", remoteIP) } func realIP(r *http.Request) string { if xff := r.Header.Get("X-Forwarded-For"); xff != "" { parts := strings.Split(xff, ",") return strings.TrimSpace(parts[0]) } if xrip := r.Header.Get("X-Real-IP"); xrip != "" { return xrip } host, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { return r.RemoteAddr } return host } func extract(base string) archives.FileHandler { return func(ctx context.Context, f archives.FileInfo) error { if f.Mode()&os.ModeSymlink != 0 { return fmt.Errorf("refusing to extract symlink: %s", f.NameInArchive) } targetPath := filepath.Join(base, f.NameInArchive) if !strings.HasPrefix(filepath.Clean(targetPath), filepath.Clean(base)+string(os.PathSeparator)) { return fmt.Errorf("invalid path: %s", f.NameInArchive) } if f.IsDir() { return os.MkdirAll(targetPath, os.FileMode(f.Mode())) } if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { return err } in, err := f.Open() if err != nil { return err } defer in.Close() out, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR|os.O_TRUNC, os.FileMode(f.Mode())) if err != nil { return err } defer out.Close() _, err = io.Copy(out, in) return err } } func removeContents(directory string) error { entries, err := os.ReadDir(directory) if err != nil { return fmt.Errorf("failed to read dir %s: %w", directory, err) } for _, entry := range entries { err := os.RemoveAll(filepath.Join(directory, entry.Name())) if err != nil { return err } } return nil } func renameOrCopyContents(sourceDirectory string) error { entries, err := os.ReadDir(sourceDirectory) if err != nil { return fmt.Errorf("failed to read dir %s: %w", sourceDirectory, err) } for _, entry := range entries { srcPath := filepath.Join(sourceDirectory, entry.Name()) dstPath := filepath.Join(targetDirectory, entry.Name()) if err := os.Rename(srcPath, dstPath); err != nil { if isCrossDeviceErr(err) { if err := copy.Copy(srcPath, dstPath); err != nil { return err } _ = os.RemoveAll(srcPath) } else { return err } } } return nil } func isCrossDeviceErr(err error) bool { var linkErr *os.LinkError return errors.As(err, &linkErr) && linkErr.Err == syscall.EXDEV } func withRecovery(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { defer func() { if v := recover(); v != nil { log.Printf("panic: %v", v) http.Error(w, "internal error", http.StatusInternalServerError) } }() next(w, r) } }