#!/usr/bin/env bash
# Integration test: ecurl must honor HTTPS_PROXY and tunnel an HTTPS POST
# through a local HTTP CONNECT proxy to a local self-signed TLS echo server.
# Requires openssl, python3, jq and ecurl on PATH. Ports may be overridden via
# S_PORT / P_PORT in the environment.
set -euo pipefail

S_PORT="${S_PORT:-18101}"   # TLS server
P_PORT="${P_PORT:-18102}"   # HTTP CONNECT proxy

# Fail fast with a clear message instead of discovering a missing tool only
# after the servers have been started.
for tool in openssl python3 jq ecurl; do
  if ! command -v "$tool" >/dev/null 2>&1; then
    printf 'missing required tool: %s\n' "$tool" >&2
    exit 1
  fi
done

# Background server PIDs, filled in once the servers start. Declared up front
# so the EXIT trap never expands an unset variable under `set -u` (that would
# abort the trap before the temp dir, which holds a private key, is removed).
SPID=""
PROXY_PID=""

# NB: reuses the standard name TMPDIR; if it was already exported, child
# processes (python3, ecurl) will see this directory as their TMPDIR too.
TMPDIR="$(mktemp -d)"

# Stop any started servers and delete the scratch directory on every exit path.
cleanup() {
  local pid
  for pid in "${SPID:-}" "${PROXY_PID:-}"; do
    if [[ -n "$pid" ]]; then
      kill "$pid" 2>/dev/null || true
    fi
  done
  rm -rf -- "${TMPDIR:?}"
}
trap cleanup EXIT

# Create a throwaway self-signed cert (CN=127.0.0.1, valid 1 day). Tool output
# is kept in a log and only shown when generation fails.
if ! openssl req -x509 -newkey rsa:2048 -nodes \
    -keyout "$TMPDIR/key.pem" -out "$TMPDIR/cert.pem" \
    -subj "/CN=127.0.0.1" -days 1 >"$TMPDIR/openssl.log" 2>&1; then
  printf 'openssl failed to create the test certificate:\n' >&2
  cat -- "$TMPDIR/openssl.log" >&2
  exit 1
fi

# Minimal HTTPS echo handler (Python). TLS is applied by run_https.py; this
# module only defines the request handler class H, which answers every request
# with a JSON object describing it (path, headers, body).
cat >"$TMPDIR/https_echo.py" <<'PY'
import ssl, json, sys
from http.server import HTTPServer, BaseHTTPRequestHandler


class H(BaseHTTPRequestHandler):
    # Speak HTTP/1.1; every reply carries Content-Length and Connection: close,
    # so clients know exactly where the body ends without waiting for EOF.
    protocol_version = "HTTP/1.1"

    def _write_json(self, obj):
        """Send the JSON headers and body; send_response() must already have run."""
        data = json.dumps(obj).encode("utf-8")
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(data)))
        self.send_header("Connection", "close")
        self.end_headers()
        self.wfile.write(data)
        self.wfile.flush()

    def _content_length(self):
        """Declared request body length; 0 when absent, malformed or negative.

        A negative value must not reach rfile.read(), where -1 means
        "read until EOF" and would block until the client hangs up.
        """
        try:
            return max(int(self.headers.get("Content-Length", 0)), 0)
        except ValueError:
            return 0

    def do_ANY(self):
        """Echo the request back as JSON with status 200."""
        length = self._content_length()
        # Replace undecodable bytes rather than failing on non-UTF-8 payloads.
        body = self.rfile.read(length).decode("utf-8", "replace") if length else ""
        out = {"path": self.path, "headers": dict(self.headers), "body": body}
        self.send_response(200)
        self._write_json(out)

    # Route every body-capable method to the same echo implementation.
    do_GET = do_ANY
    do_POST = do_ANY
    do_PUT = do_ANY
    do_PATCH = do_ANY
    do_DELETE = do_ANY
    do_OPTIONS = do_ANY

PY

# TLS launcher: serves https_echo.H on 127.0.0.1:<port> using the given cert and
# key. The heredoc is quoted so no shell value is spliced into Python source; the
# module directory is derived from __file__ instead of interpolating $TMPDIR.
cat >"$TMPDIR/run_https.py" <<'PY'
import os
import ssl
import sys
from http.server import HTTPServer

# Make the sibling https_echo.py importable regardless of the caller's cwd.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from https_echo import H

# argv: <port> <cert.pem> <key.pem>
port, crt, key = int(sys.argv[1]), sys.argv[2], sys.argv[3]
httpd = HTTPServer(('127.0.0.1', port), H)
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ctx.load_cert_chain(crt, keyfile=key)
# Wrapping the listening socket makes every accepted connection a TLS server socket.
httpd.socket = ctx.wrap_socket(httpd.socket, server_side=True)
httpd.serve_forever()
PY

# Simple HTTP CONNECT proxy (Python): accepts "CONNECT host:port", dials the
# target and relays bytes in both directions until each side finishes.
cat >"$TMPDIR/connect_proxy.py" <<'PY'
import socket, sys, threading

RESP_OK = b"HTTP/1.1 200 Connection established\r\nProxy-Agent: mini-connect\r\n\r\n"
# Error replies end the exchange, so they advertise an empty body explicitly.
RESP_BAD_REQUEST = (b"HTTP/1.1 400 Bad Request\r\nProxy-Agent: mini-connect\r\n"
                    b"Content-Length: 0\r\nConnection: close\r\n\r\n")
RESP_BAD_GATEWAY = (b"HTTP/1.1 502 Bad Gateway\r\nProxy-Agent: mini-connect\r\n"
                    b"Content-Length: 0\r\nConnection: close\r\n\r\n")
HEADER_END = b"\r\n\r\n"
MAX_HEADER_BYTES = 65536


def log(msg):
    """Write a diagnostic line to stderr (redirected to proxy.log by the script)."""
    sys.stderr.write("mini-connect: %s\n" % msg)
    sys.stderr.flush()


def read_request_headers(sock):
    """Read the request head up to and including the blank line.

    Returns (head, leftover). Headers may arrive in fragments; reading stops at
    the blank line, at EOF, or once MAX_HEADER_BYTES is exceeded. `leftover`
    holds bytes that arrived after the blank line in the same reads; they belong
    to the tunnel and must be forwarded, not dropped.
    """
    buf = bytearray()
    while HEADER_END not in buf:
        chunk = sock.recv(4096)
        if not chunk:
            break
        buf += chunk
        if len(buf) > MAX_HEADER_BYTES:
            break
    head, sep, rest = bytes(buf).partition(HEADER_END)
    return head + sep, rest


def parse_connect_host_port(req_bytes):
    """Return (host, port) from a 'CONNECT host:port HTTP/1.1' line, else (None, None).

    Accepts bracketed IPv6 literals ("[::1]:443") and rejects ports outside
    1-65535, which socket.create_connection would otherwise fail on with
    OverflowError rather than OSError.
    """
    first_line = req_bytes.split(b"\r\n", 1)[0].decode("latin1", "replace")
    parts = first_line.split()
    if len(parts) < 2 or parts[0].upper() != "CONNECT":
        return None, None
    host, sep, port_text = parts[1].rpartition(":")
    if not sep or not host:
        return None, None
    if host.startswith("[") and host.endswith("]"):
        host = host[1:-1]
    try:
        port = int(port_text)
    except ValueError:
        return None, None
    if not 0 < port < 65536:
        return None, None
    return host, port


def pump(src, dst):
    """Copy bytes from src to dst until EOF or error, then half-close dst.

    Errors end the copy quietly: the peer going away is normal tunnel teardown.
    """
    try:
        while True:
            data = src.recv(8192)
            if not data:
                break
            dst.sendall(data)
    except OSError:
        pass
    try:
        dst.shutdown(socket.SHUT_WR)  # signal no more data in this direction
    except OSError:
        pass


def handle_client(client_sock):
    """Serve one client: parse CONNECT, dial the target, then relay both ways."""
    target_sock = None
    try:
        head, leftover = read_request_headers(client_sock)
        if not head:
            return  # client closed without sending anything (e.g. a port probe)
        host, port = parse_connect_host_port(head)
        if not host:
            client_sock.sendall(RESP_BAD_REQUEST)
            return
        try:
            target_sock = socket.create_connection((host, port), timeout=5)
        except (OSError, UnicodeError) as exc:
            log("connect to %s:%d failed: %s" % (host, port, exc))
            client_sock.sendall(RESP_BAD_GATEWAY)
            return
        # The timeout is meant for the connect only; left in place it would make
        # recv() on an idle tunnel fail after 5 s and tear the tunnel down.
        target_sock.settimeout(None)
        client_sock.sendall(RESP_OK)
        if leftover:
            target_sock.sendall(leftover)

        # Full-duplex relay until both directions finish
        t1 = threading.Thread(target=pump, args=(client_sock, target_sock), daemon=True)
        t2 = threading.Thread(target=pump, args=(target_sock, client_sock), daemon=True)
        t1.start(); t2.start()
        t1.join(); t2.join()
    except OSError as exc:
        log("client error: %s" % exc)
    finally:
        try:
            client_sock.close()
        except OSError:
            pass
        if target_sock:
            try:
                target_sock.close()
            except OSError:
                pass


def main():
    """Listen on 127.0.0.1:<argv[1]> and serve each client on a daemon thread."""
    port = int(sys.argv[1])
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    srv.bind(("127.0.0.1", port))
    srv.listen(64)
    while True:
        c, _ = srv.accept()
        threading.Thread(target=handle_client, args=(c,), daemon=True).start()

if __name__ == "__main__":
    main()

PY

# wait_for_port PORT PID NAME LOG
# Poll (up to ~5 s) until 127.0.0.1:PORT accepts TCP connections. Fails fast,
# dumping LOG, if background process PID has already exited. This replaces
# fixed sleeps, which race on slow machines and hide startup crashes.
wait_for_port() {
  local port=$1 pid=$2 name=$3 log=$4
  local attempt
  for (( attempt = 0; attempt < 50; attempt++ )); do
    if ! kill -0 "$pid" 2>/dev/null; then
      printf '%s exited during startup; log follows:\n' "$name" >&2
      cat -- "$log" >&2 || true
      return 1
    fi
    # bash /dev/tcp probe: succeeds once something is listening on the port.
    if (exec 3<>"/dev/tcp/127.0.0.1/$port") 2>/dev/null; then
      return 0
    fi
    sleep 0.1
  done
  printf '%s did not start listening on port %s in time\n' "$name" "$port" >&2
  return 1
}

python3 "$TMPDIR/run_https.py" "$S_PORT" "$TMPDIR/cert.pem" "$TMPDIR/key.pem" >"$TMPDIR/https.log" 2>&1 &
SPID=$!
wait_for_port "$S_PORT" "$SPID" "HTTPS echo server" "$TMPDIR/https.log"

# Start CONNECT proxy
python3 "$TMPDIR/connect_proxy.py" "$P_PORT" >"$TMPDIR/proxy.log" 2>&1 &
PROXY_PID=$!
wait_for_port "$P_PORT" "$PROXY_PID" "CONNECT proxy" "$TMPDIR/proxy.log"

# Send HTTPS through the HTTP proxy via the proxy environment variables. Set
# both spellings, and drop any inherited NO_PROXY/no_proxy (often listing
# 127.0.0.1), so an ambient proxy configuration cannot route around the proxy
# under test.
proxy_url="http://127.0.0.1:${P_PORT}"
resp="$(env -u NO_PROXY -u no_proxy HTTPS_PROXY="$proxy_url" https_proxy="$proxy_url" \
  ecurl -u "https://127.0.0.1:${S_PORT}/echo?x=" -X POST -d "foo=bar" -H "X-Test: ecurl" -i "tls" --json -q --insecure --timeout 5)"

# NB(review): the `> 0` clause subsumes `== 200`, so this accepts any positive
# status; tighten to `== 200` if ecurl reports the origin server's status.
if ! printf '%s\n' "$resp" | jq -e '.response.status == 200 or .response.status > 0' >/dev/null; then
  printf 'unexpected ecurl output:\n%s\n' "$resp" >&2
  exit 1
fi
