diff --git a/main.py b/main.py new file mode 100755 index 0000000..8983817 --- /dev/null +++ b/main.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 + +""" +This Listen a port on 127.0.0.1 or [::1], (so no authorization +implement required), receive a .tar.gz file and extract it to +a specific path, and make it a soft link to the extracted files. + +Use the following command to upload the file content: +curl -H "Content-Type: application/octet-stream" --data-binary @main.py http://localhost:8080 + +The file content must be a legal .tar.gz file. There must not be +a subdirectory to contain other files +""" + +import os +import tarfile +import shutil +from typing import BinaryIO, Union +from http.server import BaseHTTPRequestHandler, HTTPServer +from datetime import datetime +import argparse +import logging + +class FileManager: + def __init__(self, archive_dir: str, extract_dir: str, symlink_path: str, + keep_archive: int, keep_extract: int, temp_dir: str = "/tmp"): + self.archive_dir = archive_dir + self.extract_dir = extract_dir + self.symlink_path = symlink_path + + self.keep_archive = keep_archive + self.keep_extract = keep_extract + + self.temp_dir = temp_dir + + def _get_archive_name(self) -> str: + time_str = datetime.now().isoformat(timespec="seconds") + return f"archive_{time_str}.tar.gz" + + def _extract(self, archive_path: str, target_path: str) -> bool: + try: + with tarfile.open(archive_path, mode="r:gz") as tf: + tf.extractall(target_path) + except Exception as e: + logging.error("Failed to extract tar file: {}".format(e)) + return False + return True + + + def save_file(self, src: BinaryIO, content_length: int) -> Union[str, None]: + archive_name = self._get_archive_name() + tgt_file = os.path.join(self.temp_dir, archive_name) + + logging.info("Saving to {}".format(tgt_file)) + + try: + f = open(tgt_file, "bw") + redirect_stream(src, f, content_length) + f.close() + except: + os.remove(tgt_file) + return None + + final_file = os.path.join(self.archive_dir, archive_name) + shutil.move(tgt_file, final_file) + return final_file + + def deploy(self, archive_path: str) -> bool: + extract_dir = os.path.join(self.extract_dir, os.path.basename(archive_path)) + + logging.info("Deploying to {}".format(extract_dir)) + + os.mkdir(extract_dir) + if not self._extract(archive_path, extract_dir): + return False + + os.remove(self.symlink_path) + os.symlink(extract_dir, self.symlink_path) + + return True + + def _vacuum_single(self, dirname: str, keep_count: int, rm_dir: bool) -> None: + files = os.listdir(dirname) + files.sort() + for f in files[:-keep_count]: + full_path = os.path.join(dirname, f) + + logging.info("Removing {}".format(full_path)) + if rm_dir: + shutil.rmtree(full_path) + else: + os.remove(full_path) + + def vacuum(self) -> None: + if self.keep_archive > 0: + logging.info("Vacuuming archive, keep {} finally".format(self.keep_archive)) + self._vacuum_single(self.archive_dir, self.keep_archive, False) + if self.keep_extract > 0: + logging.info("Vacuuming extract, keep {} finally".format(self.keep_extract)) + self._vacuum_single(self.extract_dir, self.keep_extract, True) + + def handle(self, instream: BinaryIO, content_length: int) -> bool: + archive_path = self.save_file(instream, content_length) + if archive_path is None: + logging.error("Failed to extract file. Aborted!") + return False + + if not self.deploy(archive_path): + return False + self.vacuum() + return True + +global_mgr: FileManager + + +def redirect_stream(src: BinaryIO, tgt: BinaryIO, size: int) -> None: + block_size = 4 * 1024 * 1024 # 4MB + + cache = src.read(size % block_size) + tgt.write(cache) + size -= size % block_size + + while size > 0: + cache = src.read(block_size) + tgt.write(cache) + size -= block_size + + +class S(BaseHTTPRequestHandler): + protocol_version = 'HTTP/1.1' + + def __init__(self, *args, **kwargs): + super(S, self).__init__(*args, **kwargs) + + def _set_response(self): + self.send_response(200) + self.send_header("Content-Type", "text/plaintext") + self.end_headers() + + # def do_GET(self): + # logging.info("GET request, Path=%s, Headers: %s", str(self.path), str(self.headers)) + # self._set_response() + # self.wfile.write("GET request for {}".format(self.path).encode("utf-8")) + + def _write_response(self, status_code: int, content_type: str, content: bytes): + self.send_response(status_code) + self.send_header("Content-Type", content_type) + self.send_header("Content-Length", str(len(content))) + self.end_headers() + + self.wfile.write(content) + + def do_POST(self): + content_length = int(self.headers["Content-Length"]) + if global_mgr.handle(self.rfile, content_length): + content = "Success".encode("utf-8") + self._write_response(200, "text/plaintext", content) + logging.info("Deploy success") + else: + content = "Failed".encode("utf-8") + self._write_response(200, "text/plaintext", content) + logging.error("Deploy failed") + self.wfile.flush() + + +def run(archive_dir: str, extract_dir: str, symlink_path: str, + keep_archive: int, keep_extract: int, + port: int = 8080, temp_dir: str = "/tmp"): + logging.basicConfig(level=logging.DEBUG) + + address = "127.0.0.1" + + logging.info("Listening on {}:{}".format(address, port)) + logging.info("Archive saves under: {}".format(archive_dir)) + logging.info("Extract tar under: {}".format(extract_dir)) + logging.info("Keep {} archives at most".format(keep_archive)) + logging.info("Keep {} extracted at most".format(keep_extract)) + logging.info("Symbolic link location: {}".format(symlink_path)) + logging.info("Temperory directory: {}".format(temp_dir)) + + global global_mgr + global_mgr = FileManager(archive_dir=archive_dir, extract_dir=extract_dir, + symlink_path=symlink_path, temp_dir=temp_dir, + keep_archive=keep_archive, keep_extract=keep_extract) + + httpd = HTTPServer((address, port), S) + logging.info("Starting httpd...") + + try: + httpd.serve_forever() + except KeyboardInterrupt: + pass + + httpd.server_close() + logging.info("Stopping httpd...") + + +if __name__ == "__main__": + ap = argparse.ArgumentParser() + ap.add_argument("--archive-dir", dest="archive_dir", type=str, + required=True, help="directory to save archives") + ap.add_argument("--extract-dir", dest="extract_dir", type=str, + required=True, help="directory to save extracted files") + ap.add_argument("--symlink-path", dest="symlink_path", type=str, + required=True, help="path of symlink which redirect to extracted archive") + + ap.add_argument("--keep-extract", dest="keep_extract", type=int, + default=4, help="Number of extracted archives to keep, 0 mean never vacuum") + ap.add_argument("--keep-archive", dest="keep_archive", type=int, + default=8, help="Number of archives to keep, 0 mean never vacuum") + + ap.add_argument("--port", dest="port", type=int, + required=True, help="listen port on 127.0.0.1, " + + "no authorization implemented so only listen on 127.0.0.1 for safety") + ap.add_argument("--temp-dir", dest="temp_dir", type=str, + default="/tmp", help="path to save in-delivery archive") + + args = ap.parse_args() + + run(archive_dir=args.archive_dir, + extract_dir=args.extract_dir, + symlink_path=args.symlink_path, + keep_archive=args.keep_archive, + keep_extract=args.keep_extract, + temp_dir=args.temp_dir, + port=args.port)