diff options
Diffstat (limited to 'ci-driver')
-rw-r--r-- | ci-driver/Cargo.toml | 30 | ||||
-rw-r--r-- | ci-driver/src/main.rs | 626 |
2 files changed, 656 insertions, 0 deletions
diff --git a/ci-driver/Cargo.toml b/ci-driver/Cargo.toml new file mode 100644 index 0000000..697f929 --- /dev/null +++ b/ci-driver/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "ci-driver" +version = "0.0.1" +authors = [ "iximeow <me@iximeow.net>" ] +license = "0BSD" +edition = "2021" + +[[bin]] +name = "ci_driver" +path = "src/main.rs" + +[dependencies] +ci-lib-core = { path = "../ci-lib-core" } +ci-lib-native = { path = "../ci-lib-native" } + +axum = { version = "*" } +axum-extra = { version = "*", features = ["async-read-body"] } +axum-server = { version = "*", features = ["tls-rustls"] } +axum-macros = "*" +serde_json = "*" +serde = { version = "*", features = ["derive"] } +base64 = "*" +tokio = { version = "*", features = ["full"] } +tokio-stream = "*" +tracing-subscriber = "*" +hyper = "*" +futures-util = "*" +lazy_static = "*" +lettre = "*" +reqwest = "*" diff --git a/ci-driver/src/main.rs b/ci-driver/src/main.rs new file mode 100644 index 0000000..f8ff34c --- /dev/null +++ b/ci-driver/src/main.rs @@ -0,0 +1,626 @@ +use std::process::Command; +use std::collections::HashMap; +use std::sync::{Mutex, RwLock}; +use lazy_static::lazy_static; +use std::io::Read; +use futures_util::StreamExt; +use std::fmt; +use std::path::{Path, PathBuf}; +use tokio::spawn; +use tokio_stream::wrappers::ReceiverStream; +use std::sync::{Arc, Weak}; +use std::time::{SystemTime, UNIX_EPOCH}; +use axum_server::tls_rustls::RustlsConfig; +use axum::body::StreamBody; +use axum::http::{StatusCode}; +use hyper::HeaderMap; +use axum::Router; +use axum::routing::*; +use axum::extract::State; +use axum::extract::BodyStream; +use axum::response::IntoResponse; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TrySendError; +use serde_json::json; +use serde::{Deserialize, Serialize}; + +use ci_lib_core::dbctx::DbCtx; +use ci_lib_core::sql; +use ci_lib_core::sql::{PendingRun, Job, Run}; +use ci_lib_core::sql::JobResult; +use ci_lib_core::sql::RunState; +use ci_lib_core::protocol::{ClientProto, CommandInfo, TaskInfo, RequestedJob}; + +lazy_static! { + static ref AUTH_SECRET: RwLock<Option<String>> = RwLock::new(None); + static ref ACTIVE_TASKS: Mutex<HashMap<u64, Weak<()>>> = Mutex::new(HashMap::new()); +} + +fn reserve_artifacts_dir(run: u64) -> std::io::Result<PathBuf> { + let mut path: PathBuf = "/root/ixi_ci_server/artifacts/".into(); + path.push(run.to_string()); + match std::fs::create_dir(&path) { + Ok(()) => { + Ok(path) + }, + Err(e) => { + if e.kind() == std::io::ErrorKind::AlreadyExists { + Ok(path) + } else { + Err(e) + } + } + } +} + +async fn activate_run(dbctx: Arc<DbCtx>, candidate: RunnerClient, job: &Job, run: &PendingRun) -> Result<(), String> { + eprintln!("activating task {:?}", run); + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("now is before epoch") + .as_millis(); + + let remote = dbctx.remote_by_id(job.remote_id).expect("query succeeds").expect("job has remote"); + let repo = dbctx.repo_by_id(remote.repo_id).expect("query succeeds").expect("remote has repo"); + + let commit_sha = dbctx.commit_sha(job.commit_id).expect("query succeeds"); + + let artifacts: PathBuf = reserve_artifacts_dir(run.id).expect("can reserve a directory for artifacts"); + + eprintln!("running {}", &repo.name); + + let res = candidate.submit(&dbctx, &run, &remote.remote_git_url, &commit_sha).await; + + let mut client_job = match res { + Ok(Some(mut client_job)) => { client_job } + Ok(None) => { + return Err("client hung up instead of acking task".to_string()); + } + Err(e) => { + // failed to submit job, move on for now + return Err(format!("failed to submit task: {:?}", e)); + } + }; + + let host_id = client_job.client.host_id; + + let connection = dbctx.conn.lock().unwrap(); + connection.execute( + "update runs set started_time=?1, host_id=?2, state=1, artifacts_path=?3, build_token=?4 where id=?5", + (now as u64, host_id, format!("{}", artifacts.display()), &client_job.client.build_token, run.id) + ) + .expect("can update"); + std::mem::drop(connection); + + spawn(async move { + client_job.run().await + }); + + Ok(()) +} + +struct RunnerClient { + tx: mpsc::Sender<Result<String, String>>, + rx: BodyStream, + host_id: u32, + build_token: String, + accepted_sources: Option<Vec<String>>, +} + +fn token_for_job() -> String { + let mut data = [0u8; 32]; + std::fs::File::open("/dev/urandom") + .unwrap() + .read_exact(&mut data) + .unwrap(); + + base64::encode(data) +} + +struct ClientJob { + dbctx: Arc<DbCtx>, + remote_git_url: String, + sha: String, + task: PendingRun, + client: RunnerClient, + // exists only as confirmation this `ClientJob` is somewhere, still alive and being processed. + task_witness: Arc<()>, +} + +impl ClientJob { + pub async fn run(&mut self) { + loop { + eprintln!("waiting on response.."); + let msg = match self.client.recv_typed::<ClientProto>().await.expect("recv works") { + Some(msg) => msg, + None => { + eprintln!("client hung up. task's done, i hope?"); + return; + } + }; + eprintln!("got {:?}", msg); + match msg { + ClientProto::NewTaskPlease { allowed_pushers, host_info } => { + eprintln!("misdirected task request (after handshake?)"); + return; + } + ClientProto::TaskStatus(task_info) => { + let (result, state): (Result<String, String>, RunState) = match task_info { + TaskInfo::Finished { status } => { + eprintln!("task update: state is finished and result is {}", status); + match status.as_str() { + "pass" => { + (Ok("success".to_string()), RunState::Finished) + }, + other => { + eprintln!("unhandled task completion status: {}", other); + (Err(other.to_string()), RunState::Error) + } + } + }, + TaskInfo::Interrupted { status, description } => { + eprintln!("task update: state is interrupted and result is {}", status); + let desc = description.unwrap_or_else(|| status.clone()); + (Err(desc), RunState::Error) + } + }; + + let job = self.dbctx.job_by_id(self.task.job_id).expect("can query").expect("job exists"); + let repo_id = self.dbctx.repo_id_by_remote(job.remote_id).unwrap().expect("remote exists"); + + for notifier in ci_lib_native::dbctx_ext::notifiers_by_repo(&self.dbctx, repo_id).expect("can get notifiers") { + if let Err(e) = notifier.tell_complete_job(&self.dbctx, repo_id, &self.sha, self.task.id, result.clone()).await { + eprintln!("could not notify {:?}: {:?}", notifier.remote_path, e); + } + } + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("now is before epoch") + .as_millis(); + + let build_result = if result.is_ok() { + JobResult::Pass + } else { + JobResult::Fail + }; + let result_desc = match result { + Ok(msg) => msg, + Err(msg) => msg, + }; + + self.dbctx.conn.lock().unwrap().execute( + "update runs set complete_time=?1, state=?2, build_result=?3, final_status=?4 where id=?5", + (now as u64, state as u64, build_result as u8, result_desc, self.task.id) + ) + .expect("can update"); + } + ClientProto::ArtifactCreate => { + eprintln!("creating artifact"); + self.client.send(serde_json::json!({ + "status": "ok", + "object_id": "10", + })).await.unwrap(); + }, + ClientProto::Metric { name, value } => { + self.dbctx.insert_metric(self.task.id, &name, &value) + .expect("TODO handle metric insert error?"); + } + ClientProto::Command(_command_info) => { + // record information about commands, start/stop, etc. probably also allow + // artifacts to be attached to commands and default to attaching stdout/stderr? + } + other => { + eprintln!("unhandled message {:?}", other); + } + } + } + } +} + +impl RunnerClient { + async fn new(sender: mpsc::Sender<Result<String, String>>, resp: BodyStream, accepted_sources: Option<Vec<String>>, host_id: u32) -> Result<Self, String> { + let token = token_for_job(); + let client = RunnerClient { + tx: sender, + rx: resp, + host_id, + build_token: token, + accepted_sources, + }; + Ok(client) + } + + async fn test_connection(&mut self) -> Result<(), String> { + self.send_typed(&ClientProto::Ping).await?; + let resp = self.recv_typed::<ClientProto>().await?; + match resp { + Some(ClientProto::Pong) => { + Ok(()) + } + Some(other) => { + Err(format!("unexpected connection test response: {:?}", other)) + } + None => { + Err("client hung up".to_string()) + } + } + } + + async fn send(&mut self, msg: serde_json::Value) -> Result<(), String> { + self.send_typed(&msg).await + } + + async fn send_typed<T: serde::Serialize>(&mut self, msg: &T) -> Result<(), String> { + self.tx.send(Ok(serde_json::to_string(msg).unwrap())) + .await + .map_err(|e| e.to_string()) + } + + async fn recv(&mut self) -> Result<Option<serde_json::Value>, String> { + match self.rx.next().await { + Some(Ok(bytes)) => { + serde_json::from_slice(&bytes) + .map(Option::Some) + .map_err(|e| e.to_string()) + }, + Some(Err(e)) => { + eprintln!("e: {:?}", e); + Err(format!("no client job: {:?}", e)) + }, + None => { + Ok(None) + } + } + } + + async fn recv_typed<T: serde::de::DeserializeOwned>(&mut self) -> Result<Option<T>, String> { + let json = self.recv().await?; + Ok(json.map(|v| serde_json::from_value(v).unwrap())) + } + + // is this client willing to run the job based on what it has told us so far? + fn will_accept(&self, job: &Job) -> bool { + match (job.source.as_ref(), self.accepted_sources.as_ref()) { + (_, None) => true, + (None, Some(_)) => false, + (Some(source), Some(accepted_sources)) => { + accepted_sources.contains(source) + } + } + } + + async fn submit(mut self, dbctx: &Arc<DbCtx>, job: &PendingRun, remote_git_url: &str, sha: &str) -> Result<Option<ClientJob>, String> { + self.send_typed(&ClientProto::new_task(RequestedJob { + commit: sha.to_string(), + remote_url: remote_git_url.to_string(), + build_token: self.build_token.to_string(), + })).await?; + match self.recv_typed::<ClientProto>().await { + Ok(Some(ClientProto::Started)) => { + let task_witness = Arc::new(()); + ACTIVE_TASKS.lock().unwrap().insert(job.id, Arc::downgrade(&task_witness)); + Ok(Some(ClientJob { + task: job.clone(), + dbctx: Arc::clone(dbctx), + sha: sha.to_string(), + remote_git_url: remote_git_url.to_string(), + client: self, + task_witness, + })) + } + Ok(Some(resp)) => { + eprintln!("invalid response: {:?}", resp); + Err("client rejected job".to_string()) + } + Ok(None) => { + Ok(None) + } + Err(e) => { + eprintln!("e: {:?}", e); + Err(e) + } + } + } +} + +impl fmt::Debug for RunnerClient { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str("RunnerClient { .. }") + } +} + +#[axum_macros::debug_handler] +async fn handle_artifact(State(ctx): State<(Arc<DbCtx>, mpsc::Sender<RunnerClient>)>, headers: HeaderMap, artifact_content: BodyStream) -> impl IntoResponse { + eprintln!("artifact request"); + let run_token = match headers.get("x-task-token") { + Some(run_token) => run_token.to_str().expect("valid string"), + None => { + eprintln!("bad artifact post: headers: {:?}\nno x-tasak-token", headers); + return (StatusCode::BAD_REQUEST, "").into_response(); + } + }; + + let (run, artifact_path, token_validity) = match ctx.0.run_for_token(&run_token).unwrap() { + Some(result) => result, + None => { + eprintln!("bad artifact post: headers: {:?}\nrun token is not known", headers); + return (StatusCode::BAD_REQUEST, "").into_response(); + } + }; + + if token_validity != sql::TokenValidity::Valid { + eprintln!("bad artifact post: headers: {:?}. token is not valid: {:?}", headers, token_validity); + return (StatusCode::BAD_REQUEST, "").into_response(); + } + + let artifact_path = if let Some(artifact_path) = artifact_path { + artifact_path + } else { + eprintln!("bad artifact post: headers: {:?}. no artifact path?", headers); + return (StatusCode::BAD_REQUEST, "").into_response(); + }; + + let artifact_name = match headers.get("x-artifact-name") { + Some(artifact_name) => artifact_name.to_str().expect("valid string"), + None => { + eprintln!("bad artifact post: headers: {:?}\nno x-artifact-name", headers); + return (StatusCode::BAD_REQUEST, "").into_response(); + } + }; + + let artifact_desc = match headers.get("x-artifact-desc") { + Some(artifact_desc) => artifact_desc.to_str().expect("valid string"), + None => { + eprintln!("bad artifact post: headers: {:?}\nno x-artifact-desc", headers); + return (StatusCode::BAD_REQUEST, "").into_response(); + } + }; + + let mut artifact = match ci_lib_native::dbctx_ext::reserve_artifact(&ctx.0, run, artifact_name, artifact_desc).await { + Ok(artifact) => artifact, + Err(err) => { + eprintln!("failure to reserve artifact: {:?}", err); + return (StatusCode::INTERNAL_SERVER_ERROR, "").into_response(); + } + }; + + eprintln!("spawning task..."); + let dbctx_ref = Arc::clone(&ctx.0); + spawn(async move { + artifact.store_all(artifact_content).await.unwrap(); + dbctx_ref.finalize_artifact(artifact.artifact_id).await.unwrap(); + }); + eprintln!("done?"); + + (StatusCode::OK, "").into_response() +} + +#[derive(Serialize, Deserialize)] +struct WorkRequest { + kind: String, + accepted_pushers: Option<Vec<String>> +} + +async fn handle_next_job(State(ctx): State<(Arc<DbCtx>, mpsc::Sender<RunnerClient>)>, headers: HeaderMap, mut job_resp: BodyStream) -> impl IntoResponse { + let auth_token = match headers.get("authorization") { + Some(token) => { + if Some(token.to_str().unwrap_or("")) != AUTH_SECRET.read().unwrap().as_ref().map(|x| &**x) { + eprintln!("BAD AUTH SECRET SUBMITTED: {:?}", token); + return (StatusCode::BAD_REQUEST, "").into_response(); + } + } + None => { + eprintln!("bad artifact post: headers: {:?}\nno authorization", headers); + return (StatusCode::BAD_REQUEST, "").into_response(); + } + }; + + let (tx_sender, tx_receiver) = mpsc::channel(8); + let resp_body = StreamBody::new(ReceiverStream::new(tx_receiver)); + tx_sender.send(Ok("hello".to_string())).await.expect("works"); + + let request = job_resp.next().await.expect("request chunk").expect("chunk exists"); + let request = std::str::from_utf8(&request).unwrap(); + let request: ClientProto = match serde_json::from_str(&request) { + Ok(v) => v, + Err(e) => { + eprintln!("couldn't parse work request: {:?}", e); + return (StatusCode::MISDIRECTED_REQUEST, resp_body).into_response(); + } + }; + let (accepted_pushers, host_info) = match request { + ClientProto::NewTaskPlease { allowed_pushers, host_info } => (allowed_pushers, host_info), + other => { + eprintln!("bad request kind: {:?}", &other); + return (StatusCode::MISDIRECTED_REQUEST, resp_body).into_response(); + } + }; + + eprintln!("client identifies itself as {:?}", host_info); + + let host_info_id = ctx.0.id_for_host(&host_info).expect("can get a host info id"); + + let client = match RunnerClient::new(tx_sender, job_resp, accepted_pushers, host_info_id).await { + Ok(v) => v, + Err(e) => { + eprintln!("unable to register client"); + return (StatusCode::MISDIRECTED_REQUEST, resp_body).into_response(); + } + }; + + match ctx.1.try_send(client) { + Ok(()) => { + eprintln!("client requested work..."); + return (StatusCode::OK, resp_body).into_response(); + } + Err(TrySendError::Full(client)) => { + return (StatusCode::IM_A_TEAPOT, resp_body).into_response(); + } + Err(TrySendError::Closed(client)) => { + panic!("client holder is gone?"); + } + } +} + +async fn make_api_server(dbctx: Arc<DbCtx>) -> (Router, mpsc::Receiver<RunnerClient>) { + let (pending_client_sender, pending_client_receiver) = mpsc::channel(8); + + let router = Router::new() + .route("/api/next_job", post(handle_next_job)) + .route("/api/artifact", post(handle_artifact)) + .with_state((dbctx, pending_client_sender)); + (router, pending_client_receiver) +} + +#[derive(Deserialize, Serialize)] +struct DriverConfig { + cert_path: PathBuf, + key_path: PathBuf, + config_path: PathBuf, + db_path: PathBuf, + server_addr: String, + auth_secret: String, +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + let mut args = std::env::args(); + args.next().expect("first arg exists"); + let config_path = args.next().unwrap_or("./driver_config.json".to_string()); + let driver_config: DriverConfig = serde_json::from_reader(std::fs::File::open(config_path).expect("file exists and is accessible")).expect("valid json for DriverConfig"); + let mut auth_secret = AUTH_SECRET.write().unwrap(); + *auth_secret = Some(driver_config.auth_secret.clone()); + std::mem::drop(auth_secret); + + let config = RustlsConfig::from_pem_file( + driver_config.cert_path.clone(), + driver_config.key_path.clone(), + ).await.unwrap(); + + let dbctx = Arc::new(DbCtx::new(&driver_config.config_path, &driver_config.db_path)); + + dbctx.create_tables().unwrap(); + + let (api_server, mut channel) = make_api_server(Arc::clone(&dbctx)).await; + spawn(axum_server::bind_rustls(driver_config.server_addr.parse().unwrap(), config) + .serve(api_server.into_make_service())); + + spawn(old_task_reaper(Arc::clone(&dbctx))); + + loop { + let mut candidate = match channel.recv().await + .ok_or_else(|| "client channel disconnected".to_string()) { + + Ok(candidate) => { candidate }, + Err(e) => { eprintln!("client error: {}", e); continue; } + }; + + let dbctx = Arc::clone(&dbctx); + spawn(async move { + let host_id = candidate.host_id; + let res = find_client_task(dbctx, candidate).await; + eprintln!("task client for {}: {:?}", host_id, res); + }); + } +} + +async fn find_client_task(dbctx: Arc<DbCtx>, mut candidate: RunnerClient) -> Result<(), String> { + let find_client_task_start = std::time::Instant::now(); + + let (run, job) = 'find_work: loop { + // try to find a job for this candidate: + // * start with pending runs - these need *some* client to run them, but do not care which + // * if no new jobs, maybe an existing job still needs a rerun on this client? + // * otherwise, um, i dunno. do nothing? + + let runs = dbctx.get_pending_runs(Some(candidate.host_id)).unwrap(); + + if runs.len() > 0 { + println!("{} new runs", runs.len()); + + for run in runs.into_iter() { + let job = dbctx.job_by_id(run.job_id).expect("can query").expect("job exists"); + + if candidate.will_accept(&job) { + break 'find_work (run, job); + } + + } + } + + let alt_run_jobs = dbctx.jobs_needing_task_runs_for_host(candidate.host_id as u64).expect("can query"); + + for job in alt_run_jobs.into_iter() { + if candidate.will_accept(&job) { + let run = dbctx.new_run(job.id, Some(candidate.host_id)).unwrap(); + break 'find_work (run, job); + } + } + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + if candidate.test_connection().await.is_err() { + return Err("lost client connection".to_string()); + } + + if find_client_task_start.elapsed().as_secs() > 300 { + return Err("5min new task deadline elapsed".to_string()); + } + }; + + eprintln!("enqueueing job {} for alternate run under host id {}", job.id, candidate.host_id); + activate_run(Arc::clone(&dbctx), candidate, &job, &run).await?; + + Ok(()) +} + +async fn old_task_reaper(dbctx: Arc<DbCtx>) { + let mut potentially_stale_tasks = dbctx.get_active_runs().unwrap(); + + let active_tasks = ACTIVE_TASKS.lock().unwrap(); + + for (id, witness) in active_tasks.iter() { + if let Some(idx) = potentially_stale_tasks.iter().position(|task| task.id == *id) { + potentially_stale_tasks.swap_remove(idx); + } + } + + std::mem::drop(active_tasks); + + // ok, so we have tasks that are not active, now if the task is started we know someone should + // be running it and they are not. retain only those tasks, as they are ones we may want to + // mark dead. + // + // further, filter out any tasks created in the last 60 seconds. this is a VERY generous grace + // period for clients that have accepted a job but for some reason we have not recorded them as + // active (perhaps they are slow to ack somehow). + + let stale_threshold = ci_lib_core::now_ms() - 60_000; + + let stale_tasks: Vec<Run> = potentially_stale_tasks.into_iter().filter(|task| { + match (task.state, task.start_time) { + // `run` is atomically set to `Started` and adorned with a `start_time`. disagreement + // between the two means this run is corrupt and should be reaped. + (RunState::Started, None) => { + true + }, + (RunState::Started, Some(start_time)) => { + start_time < stale_threshold + } + // and if it's not `started`, it's either pending (not assigned yet, so not stale), or + // one of the complete statuses. + _ => { + false + } + } + }).collect(); + + for task in stale_tasks.iter() { + eprintln!("looks like task {} is stale, reaping", task.id); + dbctx.reap_task(task.id).expect("works"); + } +} |