You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
447 lines
18 KiB
447 lines
18 KiB
import sys
|
|
import subprocess
|
|
# --- Dependency Check & Auto-Install ---
|
|
import importlib

# Each entry is (module to probe, pip package to install when it is missing).
# A pip name of None marks a module that cannot be installed through pip.
required_imports = [
    ("getpass", None),
    ("json", None),
    ("logging", None),
    ("os", None),
    ("threading", None),
    ("time", None),
    ("functools", None),
    ("multiprocessing", None),
    ("pathlib", None),
    ("typing", None),
    ("urllib.parse", None),
    ("pandas", "pandas"),
    ("requests", "requests"),
    ("tqdm", "tqdm"),
    ("pandas.io.formats.style", "pandas"),
]

for mod, pipname in required_imports:
    try:
        # Only the top-level package is probed; for a dotted name that is the
        # part before the first dot, for a plain name it is the name itself.
        __import__(mod.split(".")[0])
    except ImportError:
        if pipname:
            print(f"Installing missing package: {pipname}")
            subprocess.check_call([sys.executable, "-m", "pip", "install", pipname])
            # The package was created after the interpreter started; refresh
            # the import system's caches so the imports further down see it.
            importlib.invalidate_caches()
        else:
            print(f"Module {mod} is a built-in or not installable via pip.")
|
|
# --- Script Description ---
|
|
"""
Autosimulator for WorldQuant BRAIN platform
- Timestamped logger
- Authentication with biometric check
- User-specified alpha JSON input
- Single/multi simulation mode
- Simulation worker: sends jobs, retries, saves locations
- Result worker: fetches results, saves to JSON
"""
|
|
import os
|
|
import sys
|
|
import time
|
|
import json
|
|
import threading
|
|
import logging
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
import requests
|
|
import getpass
|
|
|
|
# Platform specific imports
|
|
if sys.platform == 'win32':
|
|
import msvcrt
|
|
else:
|
|
import tty
|
|
import termios
|
|
from ace_lib import (
|
|
check_session_and_relogin,
|
|
simulate_single_alpha,
|
|
simulate_multi_alpha,
|
|
)
|
|
|
|
# --- Logger Setup ---
|
|
def setup_logger():
    """
    Create the timestamped autosim logger.

    The logger is named ``autosim_<YYYYmmdd_HHMMSS>``, is set to DEBUG, and
    sends every record to two handlers sharing one
    ``time - level - message`` format: a file handler writing
    ``autosim_<YYYYmmdd_HHMMSS>.log`` (relative to the current working
    directory) and a stream handler.

    Returns:
        tuple: ``(logger, log_filename)``.
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_filename = f'autosim_{timestamp}.log'
    logger = logging.getLogger(f'autosim_{timestamp}')
    logger.setLevel(logging.DEBUG)

    # getLogger() hands back the same object for the same name, so a second
    # call within the same second must not attach a second pair of handlers
    # (every record would then be emitted twice).
    if logger.handlers:
        return logger, log_filename

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Explicit utf-8, matching every other file this script opens, so log
    # records containing non-ASCII text do not depend on the platform default.
    fh = logging.FileHandler(log_filename, encoding='utf-8')
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    ch = logging.StreamHandler()
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    return logger, log_filename
|
|
|
|
# Module-wide logger used by the functions below; log_filename is the file it writes to.
logger, log_filename = setup_logger()
|
|
|
|
# --- Authentication ---
|
|
def check_session_timeout(s):
    """
    Check if the current session has timed out.

    Sends a GET to ``<BRAIN_API_URL>/authentication`` (the base URL comes from
    the ``BRAIN_API_URL`` environment variable, defaulting to
    ``https://api.worldquantbrain.com``) and reads ``token.expiry`` from the
    JSON response.

    Args:
        s (SingleSession): The current session object.

    Returns:
        The ``token.expiry`` value from the response (presumably the number
        of seconds until the session expires -- confirm against the API), or
        0 if the request or the parsing of its response failed for any reason.
    """
    brain_api_url = os.environ.get("BRAIN_API_URL", "https://api.worldquantbrain.com")
    authentication_url = brain_api_url + "/authentication"
    try:
        result = s.get(authentication_url).json()["token"]["expiry"]
    except Exception as exc:
        # Callers treat 0 as "expired"; include the cause so a network or
        # parsing failure can be told apart from a real expiry in the log.
        logger.error(f"Session timeout check failed for session (ID: {id(s)}): {exc!r}")
        return 0
    logger.debug(f"Session (ID: {id(s)}) timeout check result: {result}")
    return result
|
|
def _read_masked_password_windows():
    """Read a password key-by-key with msvcrt.getch(), echoing '*' per character."""
    password = []
    while True:
        char = msvcrt.getch()

        # Enter ends the input
        if char in (b'\r', b'\n'):
            print()  # New line
            break

        # Special keys (arrows, function keys, ...) arrive as a b'\x00' or
        # b'\xe0' prefix followed by a key code on the next call. The key code
        # is frequently a printable ASCII byte, so it must be consumed here
        # or it would be appended to the password.
        if char in (b'\x00', b'\xe0'):
            msvcrt.getch()
            continue

        # Backspace
        if char == b'\x08':
            if password:
                password.pop()
                # Move cursor back, print space, move cursor back again
                print('\b \b', end='', flush=True)
            continue

        # Ctrl+C
        if char == b'\x03':
            print()
            raise KeyboardInterrupt

        # Printable ASCII range
        if 32 <= ord(char) <= 126:
            password.append(char.decode('ascii'))
            print('*', end='', flush=True)
            continue

        # Anything else: keep it only if it decodes to a printable character
        try:
            decoded_char = char.decode('utf-8')
        except UnicodeDecodeError:
            continue
        if decoded_char.isprintable():
            password.append(decoded_char)
            print('*', end='', flush=True)

    return ''.join(password)


def _read_masked_password_posix():
    """Read a password in raw terminal mode, echoing '*' per character."""
    fd = sys.stdin.fileno()
    old_settings = termios.tcgetattr(fd)
    password = []
    try:
        tty.setraw(fd)
        while True:
            char = sys.stdin.read(1)

            # Enter ends the input. An empty read means end-of-file; it must
            # end the input too, because ''.isprintable() is True and the
            # loop would otherwise append empty strings forever.
            if char == '' or char in ('\r', '\n'):
                print('\r\n', end='', flush=True)
                break

            # Backspace (DEL or BS depending on the terminal)
            if char in ('\x7f', '\x08'):
                if password:
                    password.pop()
                    print('\b \b', end='', flush=True)
                continue

            # Ctrl+C is delivered as a plain byte in raw mode
            if char == '\x03':
                print('\r\n', end='', flush=True)
                raise KeyboardInterrupt

            if char.isprintable():
                password.append(char)
                print('*', end='', flush=True)
    finally:
        # Always restore the terminal, even on Ctrl+C or an error
        termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)

    return ''.join(password)


def get_credentials():
    """
    Prompt for an email address and a password shown as asterisks.

    The password is read one key at a time: with ``msvcrt`` when
    ``sys.platform`` is ``'win32'``, otherwise with ``tty``/``termios`` in raw
    mode. If that reader raises any ``Exception`` (for example when stdin is
    not a terminal), the error is printed and ``getpass.getpass()`` is used
    instead.

    Returns:
        tuple: ``(email, password)``.

    Raises:
        KeyboardInterrupt: If Ctrl+C is pressed while typing the password.
    """
    email = input("Email: ").strip()
    print("Password: ", end='', flush=True)

    try:
        if sys.platform == 'win32':
            password = _read_masked_password_windows()
        else:
            password = _read_masked_password_posix()
    except Exception as e:
        # Fallback to getpass
        print(f"\nError reading password: {e}")
        print("Falling back to getpass...")
        return (email, getpass.getpass())

    return (email, password)
|
|
|
|
def authenticate():
    """
    Log in to the BRAIN API interactively and return the authenticated session.

    Credentials are requested with ``get_credentials()`` and posted to
    ``<BRAIN_API_URL>/authentication``. On a 401 response whose
    ``WWW-Authenticate`` header is ``persona`` the user is asked to complete
    biometric authentication, and the persona URL is posted to until it
    answers 201. On any other 401 the credentials are requested again. Any
    other status code is returned to the caller as-is without further checks.

    Returns:
        SingleSession: The session object carrying the credentials.
    """
    from urllib.parse import urljoin
    from ace_lib import SingleSession

    brain_api_url = os.environ.get("BRAIN_API_URL", "https://api.worldquantbrain.com")

    # Loop instead of recursing so repeated wrong passwords cannot grow the stack.
    while True:
        session = SingleSession()
        session.auth = get_credentials()
        r = session.post(brain_api_url + "/authentication")

        # The body is only wanted for the debug log; a non-JSON body (e.g. an
        # HTML error page) must not abort authentication.
        try:
            response_body = r.json()
        except ValueError:
            response_body = r.text
        logger.debug(f"New session created (ID: {id(session)}) with authentication response: {r.status_code}, {response_body}")

        if r.status_code != requests.status_codes.codes.unauthorized:
            return session

        if r.headers.get("WWW-Authenticate") != "persona":
            logger.error("\nIncorrect email or password\n")
            continue

        # NOTE(review): the message below treats the Location header as an
        # inquiry id, while the POSTs treat it as a URL -- confirm which one
        # the API actually sends.
        print(
            "Complete biometrics authentication and press any key to continue: \n"
            + r.url + "/persona?inquiry=" + r.headers.get("Location", "")
            + "\n"
        )
        input()

        # Resolve a relative Location against the request URL; an absolute
        # Location is returned unchanged and a missing one yields r.url.
        persona_url = urljoin(r.url, r.headers.get("Location", ""))
        session.post(persona_url)
        while session.post(persona_url).status_code != 201:
            input(
                "Biometrics authentication is not complete. Please try again and press any key when completed \n"
            )
        return session
|
|
|
|
# --- User Input ---
|
|
# JSON file mapping each input file's base name to the last alpha index sent for it.
MASTER_LOG_PATH = "autosim_master_log.json"
|
|
|
|
def update_master_log(input_json_path, latest_index):
    """
    Update the master log file with the latest successful index for the given input file name.

    The master log (``MASTER_LOG_PATH``) is a JSON object keyed by the base
    name of the input file. An existing log that cannot be read, cannot be
    parsed, or does not contain a JSON object is discarded and rebuilt.

    Args:
        input_json_path: Path of the alpha JSON file; only its base name is stored.
        latest_index: Index to record for that file.
    """
    import os
    import json

    file_name = os.path.basename(input_json_path)
    log_data = {}

    # Read existing log if present
    if os.path.exists(MASTER_LOG_PATH):
        try:
            with open(MASTER_LOG_PATH, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # ValueError covers both malformed JSON and undecodable bytes
            loaded = None
        # Valid JSON that is not an object (a list, a string, ...) cannot be
        # indexed by file name below, so it is treated like a corrupt log.
        if isinstance(loaded, dict):
            log_data = loaded

    # Update with latest index
    log_data[file_name] = latest_index

    # Atomic write: readers see either the old or the new file, never a partial one
    tmp_path = MASTER_LOG_PATH + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(log_data, f, indent=2)
    os.replace(tmp_path, MASTER_LOG_PATH)
|
|
def get_user_json():
    """
    Ask for the alpha JSON file until one loads, then apply the resume offset.

    Surrounding quotes are stripped from the entered path. When the master
    log (``MASTER_LOG_PATH``) has an entry for the file's base name, the user
    may resume after that index, enter another non-negative index, or fall
    back to 0 by typing anything else.

    Returns:
        tuple: ``(alpha_list, json_path)``. ``alpha_list`` is a ``list``
        subclass holding the alphas from the start index onwards; the start
        index itself is stored on it as ``_start_index``.
    """
    import re

    class AlphaList(list):
        """Plain list that can carry the ``_start_index`` attribute."""

    while True:
        raw_path = input('Enter path to alpha JSON file: ').strip()
        json_path = re.sub(r'^["\']+|["\']+$', '', raw_path.strip())

        if os.path.exists(json_path):
            try:
                with open(json_path, 'r', encoding='utf-8') as f:
                    alpha_list = json.load(f)

                # Slicing a JSON string would silently produce a list of
                # characters, so anything but an array is rejected up front.
                if not isinstance(alpha_list, list):
                    raise ValueError('top-level JSON value must be an array of alphas')

                # Check master log for previous progress
                file_name = os.path.basename(json_path)
                start_index = 0
                if os.path.exists(MASTER_LOG_PATH):
                    try:
                        with open(MASTER_LOG_PATH, 'r', encoding='utf-8') as logf:
                            log_data = json.load(logf)
                        if file_name in log_data:
                            last_index = log_data[file_name]
                            print(f'Last time you simulated to position {last_index}.')
                            resp = input(f'Do you want to start from {last_index + 1}? (Y/n) Or enter another starting index: ').strip()
                            if resp.lower() in ['', 'y', 'yes']:
                                start_index = last_index + 1
                            elif resp.isdigit():
                                start_index = int(resp)
                            else:
                                print('Invalid input, starting from 0.')
                                start_index = 0
                    except (OSError, ValueError, TypeError) as log_err:
                        # An unreadable or malformed master log must not block
                        # the run, but the user should know why progress was
                        # not offered.
                        logger.warning(f'Could not use master log {MASTER_LOG_PATH}: {log_err}. Starting from 0.')
                        start_index = 0

                # Slice alpha_list to start from chosen index
                alpha_list = AlphaList(alpha_list[start_index:])
                alpha_list._start_index = start_index
                if not alpha_list:
                    logger.warning(f'No alphas at or after index {start_index} in {json_path}; nothing will be simulated.')
                return alpha_list, json_path
            except Exception as e:
                logger.error(f'Error reading JSON file: {e}')
        else:
            logger.error(f'JSON file not found: {json_path}')

        print('Please enter a valid path to your alpha JSON file.')
|
|
|
|
def get_simulation_mode():
    """
    Ask for the simulation mode and, in multi mode, the batch size.

    An unknown mode is logged as an error and terminates the process with
    exit status 1. In multi mode the batch size is requested repeatedly until
    an integer between 2 and 10 is entered.

    Returns:
        tuple: ``(mode, batch_size)`` where ``mode`` is ``'single'`` or
        ``'multi'`` and ``batch_size`` is ``None`` in single mode.
    """
    mode = input('Select simulation mode (single/multi): ').strip().lower()
    if mode not in ('single', 'multi'):
        logger.error('Invalid mode. Choose "single" or "multi".')
        sys.exit(1)

    batch_size = None
    if mode == 'multi':
        while True:
            answer = input('Enter number of elements per multi-simulation batch (2-10): ').strip()
            # Only a failed int() conversion is retried. Catching everything
            # here would also swallow EOFError/OSError from a closed stdin
            # and spin in this loop forever.
            try:
                batch_size = int(answer)
            except ValueError:
                print('Please enter a valid integer between 2 and 10.')
                continue
            if 2 <= batch_size <= 10:
                break
            print('Batch size must be between 2 and 10.')

    return mode, batch_size
|
|
|
|
def get_retry_timeout():
    """
    Ask for the retry timeout in seconds.

    Returns:
        int: The entered value, or 60 when the input cannot be read or
        converted to an integer, or is smaller than 1.
    """
    fallback_seconds = 60
    try:
        answer = input('Enter retry timeout in seconds (default 60): ')
        seconds = int(answer.strip())
    except Exception:
        return fallback_seconds
    return seconds if seconds >= 1 else fallback_seconds
|
|
|
|
# --- Simulation Worker ---
|
|
def simulation_worker(session, alpha_list, mode, json_path, location_path, retry_timeout, batch_size=None):
    """
    Send the queued alphas for simulation and record the returned locations.

    Alphas are taken from the front of ``alpha_list`` (one at a time in
    ``'single'`` mode, ``batch_size`` at a time otherwise) and are removed
    from it only once the API has answered with a ``Location`` header. Each
    location is stored under a ``str(time.time())`` key and the whole mapping
    is rewritten to ``location_path``; the index of the last alpha sent is
    recorded through ``update_master_log``.

    The worker stops when the list is empty or when
    ``check_session_timeout`` reports 0. A missing ``Location`` header or any
    exception leads to a retry after ``retry_timeout`` seconds.

    Args:
        session: Authenticated session; refreshed via ``check_session_and_relogin``.
        alpha_list: Mutable list of alphas; consumed in place. An optional
            ``_start_index`` attribute gives the index of its first element
            within the original input file (0 when absent).
        mode: ``'single'`` or anything else for multi-simulation.
        json_path: Path of the input file, used as the master-log key.
        location_path: JSON file the locations are written to.
        retry_timeout: Seconds to wait before a retry.
        batch_size: Alphas per multi-simulation; when falsy, up to 10 are sent.
    """
    locations = {}
    # Initialize sent_count from user starting index (passed via alpha_list attribute if set)
    sent_count = getattr(alpha_list, '_start_index', 0)

    while alpha_list:
        # Check session timeout before proceeding
        if check_session_timeout(session) == 0:
            logger.error('Session expired. Stopping simulation worker.')
            break
        session = check_session_and_relogin(session)

        # Prepare batch but do NOT pop yet
        if mode == 'single':
            batch = [alpha_list[0]]
        else:
            size = batch_size if batch_size else min(10, max(2, len(alpha_list)))
            batch = list(alpha_list[:size])

        # A lone leftover alpha in multi mode is sent as a plain single
        # simulation (presumably the multi endpoint needs at least two
        # entries -- confirm); the result worker handles both shapes.
        payload = batch[0] if mode == 'single' or len(batch) == 1 else batch

        try:
            from ace_lib import start_simulation

            location = None
            while location is None:
                # Check session timeout before each send
                if check_session_timeout(session) == 0:
                    logger.error('Session expired. Stopping simulation worker.')
                    return
                response = start_simulation(session, payload)
                location = response.headers.get('Location')
                if location is None:
                    logger.info('Simulation sent, location(s) saved: None')
                    logger.info(f'No location received, waiting {retry_timeout} seconds and retrying...')
                    time.sleep(retry_timeout)

            # Update all in-memory state before touching any file: if a write
            # below fails, the location is still held in `locations` and is
            # flushed by the next successful write, and the batch is not
            # submitted a second time.
            del alpha_list[:len(batch)]
            sent_count += len(batch)
            locations[str(time.time())] = location

            # The result worker reads this file from another thread, so it is
            # replaced atomically rather than rewritten in place.
            tmp_location_path = location_path + '.tmp'
            with open(tmp_location_path, 'w', encoding='utf-8') as f:
                json.dump(locations, f, indent=2)
            os.replace(tmp_location_path, location_path)

            update_master_log(json_path, sent_count - 1)
            # Do NOT overwrite the input JSON file
            logger.info(f'Simulation sent, location(s) saved: {location}')
        except Exception as e:
            logger.error(f'Simulation error: {e}. Retrying in {retry_timeout} seconds.')
            time.sleep(retry_timeout)
|
|
|
|
# --- Result Worker ---
|
|
def result_worker(session, location_path, result_path, poll_interval=30):
    """
    Poll the recorded simulation locations and save their results as JSON.

    Every ``poll_interval`` seconds the location file is re-read. Each
    location not yet present in the results is polled until the response no
    longer carries a non-zero ``Retry-After`` header, then the alpha (or, for
    a completed multi-simulation, the alpha of every child) is fetched and
    the whole result mapping is rewritten to ``result_path``.

    NOTE: the loop only ends when ``check_session_timeout`` reports 0; it does
    not stop on its own once every location has a result.

    Args:
        session: Authenticated session; refreshed via ``check_session_and_relogin``.
        location_path: JSON file mapping keys to simulation progress URLs.
        result_path: JSON file the results are written to.
        poll_interval: Seconds to sleep between passes over the location file.
    """
    results = {}
    # Same base-URL convention as the authentication code in this file
    brain_api_url = os.environ.get("BRAIN_API_URL", "https://api.worldquantbrain.com")

    while True:
        # Check session timeout before proceeding
        if check_session_timeout(session) == 0:
            logger.error('Session expired. Stopping result worker.')
            break
        session = check_session_and_relogin(session)

        if not os.path.exists(location_path):
            time.sleep(poll_interval)
            continue

        # The file is produced by another thread; a read that hits a partial
        # or unreadable file must not kill this worker.
        try:
            with open(location_path, 'r', encoding='utf-8') as f:
                locations = json.load(f)
        except (OSError, ValueError) as e:
            logger.warning(f'Could not read location file {location_path}: {e}. Retrying in {poll_interval} seconds.')
            time.sleep(poll_interval)
            continue

        for loc_key, loc_val in locations.items():
            if loc_key in results:
                continue
            if not loc_val or not isinstance(loc_val, str) or not loc_val.startswith('http'):
                logger.error(f'Invalid or missing location for key {loc_key}: {loc_val}')
                continue
            try:
                # Check session timeout before each result fetch
                if check_session_timeout(session) == 0:
                    logger.error('Session expired. Stopping result worker.')
                    return

                simulation_progress_url = loc_val
                while True:
                    simulation_progress = session.get(simulation_progress_url)
                    retry_after = float(simulation_progress.headers.get("Retry-After", 0))
                    if retry_after == 0:
                        break
                    logger.info(f"Sleeping for {simulation_progress.headers.get('Retry-After')} seconds for location {simulation_progress_url}")
                    time.sleep(retry_after)
                sim_json = simulation_progress.json()

                # Multi-simulation: check for children
                if "children" in sim_json and sim_json.get("status") == "COMPLETE":
                    child_results = {}
                    for child_id in sim_json["children"]:
                        child_url = f"{brain_api_url}/simulations/{child_id}"
                        child_json = session.get(child_url).json()
                        alpha_id = child_json.get("alpha")
                        if not alpha_id:
                            logger.error(f"No alpha_id found for child {child_id}")
                            child_results[child_id] = {"error": "No alpha_id found"}
                        else:
                            alpha = session.get(f"{brain_api_url}/alphas/{alpha_id}")
                            child_results[child_id] = alpha.json()
                    results[loc_key] = {"multi_children": child_results}
                    logger.info(f"Multi-simulation results fetched for location {loc_val}")
                else:
                    # Single simulation
                    alpha_id = sim_json.get("alpha")
                    if not alpha_id:
                        logger.error(f"No alpha_id found for location {simulation_progress_url}")
                        results[loc_key] = {"error": "No alpha_id found"}
                    else:
                        alpha = session.get(f"{brain_api_url}/alphas/{alpha_id}")
                        results[loc_key] = alpha.json()
                        logger.info(f"Result fetched for location {loc_val}")

                # Replace the result file atomically so an interrupted write
                # never leaves a truncated file behind.
                tmp_result_path = result_path + '.tmp'
                with open(tmp_result_path, 'w', encoding='utf-8') as f:
                    json.dump(results, f, indent=2)
                os.replace(tmp_result_path, result_path)
            except Exception as e:
                logger.error(f'Error fetching result for {loc_val}: {e}')

        time.sleep(poll_interval)
|
|
|
|
# --- Main ---
|
|
def main():
    """
    Run the autosimulator end to end.

    Authenticates, collects the alpha file, simulation mode and retry timeout
    from the user, then runs the simulation worker and the result worker in
    two threads sharing one session and one timestamped location file. The
    call returns once both threads have finished.
    """
    session = authenticate()
    alpha_list, json_path = get_user_json()
    mode, batch_size = get_simulation_mode()
    retry_timeout = get_retry_timeout()

    # One timestamp ties the location file and the result file of this run together
    run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    location_path = f'autosim_locations_{run_stamp}.json'
    result_path = f'autosim_results_{run_stamp}.json'

    # Start workers
    sender_args = (session, alpha_list, mode, json_path, location_path, retry_timeout, batch_size)
    fetcher_args = (session, location_path, result_path)
    sim_thread = threading.Thread(target=simulation_worker, args=sender_args)
    res_thread = threading.Thread(target=result_worker, args=fetcher_args)
    for worker in (sim_thread, res_thread):
        worker.start()

    sim_thread.join()
    logger.info('Simulation worker finished. Waiting for results...')
    # Blocks until the result worker returns
    res_thread.join()
|
|
|
|
# Run the autosimulator only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|
|