245 lines
5.7 KiB
Go
245 lines
5.7 KiB
Go
package main
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/mholt/archives"
	"github.com/otiai10/copy"
)
|
|
|
|
// Process-wide settings. The environment-derived values are read once, when
// the package is initialised.
var (
	// deployLock is held by handle while the target directory is emptied and
	// repopulated, so two deployments never interleave.
	deployLock sync.Mutex

	// maxUploadSize is the request-body limit in bytes (1 GiB) that handle
	// passes to http.MaxBytesReader.
	maxUploadSize = int64(1 << 30)

	// port is the TCP port to listen on; main substitutes "8080" when empty.
	port = os.Getenv("PORT")

	// targetDirectory is the directory whose contents each upload replaces.
	targetDirectory = os.Getenv("TARGET_DIR")

	// subPath is the optional endpoint path with surrounding slashes removed.
	subPath = strings.Trim(os.Getenv("SUBPATH"), "/")

	// authenticationKey is the value the Authorization header must equal.
	authenticationKey = os.Getenv("AUTH_KEY")
)
|
|
|
|
func main() {
|
|
if authenticationKey == "" || targetDirectory == "" {
|
|
log.Fatal("AUTH_KEY and TARGET_DIR must be set")
|
|
}
|
|
|
|
if port == "" {
|
|
port = "8080"
|
|
}
|
|
|
|
basePath := "/"
|
|
if subPath != "" {
|
|
basePath = "/" + subPath
|
|
}
|
|
|
|
log.Printf("starting server on :%s, endpoint %q, target directory %q", port, basePath, targetDirectory)
|
|
http.HandleFunc(basePath, withRecovery(handle))
|
|
log.Fatal(http.ListenAndServe(":"+port, nil))
|
|
}
|
|
|
|
func handle(w http.ResponseWriter, r *http.Request) {
|
|
remoteIP := realIP(r)
|
|
|
|
log.Printf("incoming %q request on %q from %s", r.Method, r.URL.Path, remoteIP)
|
|
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
auth := r.Header.Get("Authorization")
|
|
if subtle.ConstantTimeCompare([]byte(auth), []byte(authenticationKey)) != 1 {
|
|
log.Printf("unauthorized request from %s", remoteIP)
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
r.Body = http.MaxBytesReader(w, r.Body, maxUploadSize)
|
|
|
|
uploadDir, err := os.MkdirTemp("", "upload-*")
|
|
if err != nil {
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer os.RemoveAll(uploadDir)
|
|
|
|
uploadFilePath := filepath.Join(uploadDir, "file")
|
|
uploadFile, err := os.Create(uploadFilePath)
|
|
if err != nil {
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer uploadFile.Close()
|
|
|
|
if _, err := io.Copy(uploadFile, r.Body); err != nil {
|
|
http.Error(w, "bad archive", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
uploadFile.Seek(0, io.SeekStart)
|
|
defer uploadFile.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
format, archiveStream, err := archives.Identify(ctx, uploadFilePath, uploadFile)
|
|
if err != nil {
|
|
http.Error(w, "cannot detect archive", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
extractor, ok := format.(archives.Extractor)
|
|
if !ok {
|
|
http.Error(w, "unsupported archive type", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
extractDir, err := os.MkdirTemp("", "extract-*")
|
|
if err != nil {
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer os.RemoveAll(extractDir)
|
|
|
|
if err := extractor.Extract(ctx, archiveStream, extract(extractDir)); err != nil {
|
|
log.Printf("failed to extract archive: %v", err)
|
|
http.Error(w, "bad archive", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
deployLock.Lock()
|
|
defer deployLock.Unlock()
|
|
|
|
if err := removeContents(targetDirectory); err != nil {
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if err := renameOrCopyContents(extractDir); err != nil {
|
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
log.Printf("upload successful from %s", remoteIP)
|
|
}
|
|
|
|
// realIP returns a best-effort client address for r.
//
// It prefers the first non-empty element of X-Forwarded-For, then a
// non-empty X-Real-IP, and finally the host part of r.RemoteAddr (or
// r.RemoteAddr unchanged when it is not in host:port form). Header values are
// whitespace-trimmed.
//
// Both headers are supplied by the client or an intermediary, so the result
// is not authenticated; in this file it is only used in log messages.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		// An empty leading element (e.g. ", 10.0.0.1") carries no address;
		// fall through instead of reporting an empty string.
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
|
|
|
|
func extract(base string) archives.FileHandler {
|
|
return func(ctx context.Context, f archives.FileInfo) error {
|
|
if f.Mode()&os.ModeSymlink != 0 {
|
|
return fmt.Errorf("refusing to extract symlink: %s", f.NameInArchive)
|
|
}
|
|
|
|
targetPath := filepath.Join(base, f.NameInArchive)
|
|
if !strings.HasPrefix(filepath.Clean(targetPath), filepath.Clean(base)+string(os.PathSeparator)) {
|
|
return fmt.Errorf("invalid path: %s", f.NameInArchive)
|
|
}
|
|
|
|
if f.IsDir() {
|
|
return os.MkdirAll(targetPath, os.FileMode(f.Mode()))
|
|
}
|
|
|
|
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
|
|
return err
|
|
}
|
|
|
|
in, err := f.Open()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer in.Close()
|
|
|
|
out, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR|os.O_TRUNC, os.FileMode(f.Mode()))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer out.Close()
|
|
|
|
_, err = io.Copy(out, in)
|
|
return err
|
|
}
|
|
}
|
|
|
|
// removeContents deletes every entry inside directory, recursively, while
// leaving directory itself in place.
//
// It returns a wrapped error when directory cannot be listed, and otherwise
// the first removal error encountered (entries after it are left untouched).
func removeContents(directory string) error {
	children, listErr := os.ReadDir(directory)
	if listErr != nil {
		return fmt.Errorf("failed to read dir %s: %w", directory, listErr)
	}

	for i := range children {
		victim := filepath.Join(directory, children[i].Name())
		if rmErr := os.RemoveAll(victim); rmErr != nil {
			return rmErr
		}
	}
	return nil
}
|
|
|
|
func renameOrCopyContents(sourceDirectory string) error {
|
|
entries, err := os.ReadDir(sourceDirectory)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read dir %s: %w", sourceDirectory, err)
|
|
}
|
|
|
|
for _, entry := range entries {
|
|
srcPath := filepath.Join(sourceDirectory, entry.Name())
|
|
dstPath := filepath.Join(targetDirectory, entry.Name())
|
|
|
|
if err := os.Rename(srcPath, dstPath); err != nil {
|
|
if isCrossDeviceErr(err) {
|
|
if err := copy.Copy(srcPath, dstPath); err != nil {
|
|
return err
|
|
}
|
|
_ = os.RemoveAll(srcPath)
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isCrossDeviceErr reports whether err is, or wraps, an *os.LinkError whose
// underlying error is exactly syscall.EXDEV.
func isCrossDeviceErr(err error) bool {
	var le *os.LinkError
	if !errors.As(err, &le) {
		return false
	}
	return le.Err == syscall.EXDEV
}
|
|
|
|
// withRecovery wraps next so that a panic raised while serving a request is
// recovered, logged, and answered with a 500 "internal error" response
// instead of propagating out of the handler.
func withRecovery(next http.HandlerFunc) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			cause := recover()
			if cause == nil {
				return
			}
			log.Printf("panic: %v", cause)
			http.Error(w, "internal error", http.StatusInternalServerError)
		}()

		next.ServeHTTP(w, r)
	})
}
|