From 431045da588c9df87458d43761b43c48a085d391 Mon Sep 17 00:00:00 2001 From: vincent cardile <vincent.cardile2@etu.unistra.fr> Date: Sun, 24 Mar 2024 17:59:04 +0100 Subject: [PATCH] =?UTF-8?q?suite=20factorisation,=20ajout=20d=C3=A9marage?= =?UTF-8?q?=20du=20mod=C3=A8le=20de=20language=20pour=20le=20cas=20ou?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/data.rs | 1 - src/main.rs | 62 +++++++++++++++++++++++++++++++++++++++++++++-------- src/node.rs | 20 +++++++++++++++-- 3 files changed, 71 insertions(+), 12 deletions(-) diff --git a/src/data.rs b/src/data.rs index b753395..c684cb8 100644 --- a/src/data.rs +++ b/src/data.rs @@ -3,7 +3,6 @@ use serde::{ Serialize, Deserialize }; - /* * Node struct * diff --git a/src/main.rs b/src/main.rs index 42b7bec..e571e96 100755 --- a/src/main.rs +++ b/src/main.rs @@ -23,12 +23,24 @@ mod node; */ fn parse_json(arg: String, parsed: &mut Value, debug: bool) -> std::io::Result<()> { - let mut file = File::open(arg)?; // ? catch any error and return it to the caller + if arg == "exit" { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "exit")); + } - let mut content = String::new(); - file.read_to_string(&mut content)?; + let mut _file: Option<File> = None; + match File::open(arg) { + Ok(opened_file) => _file = Some(opened_file), + Err(_) => { + let _ = parse_json(get_input(Some("Invalid path")), parsed, debug); + return Ok(()); + } + } - *parsed = serde_json::from_str(&content)?; + if let Some(ref mut actual_file) = _file { + let mut content = String::new(); + actual_file.read_to_string(&mut content)?; + *parsed = serde_json::from_str(&content)?; + } if debug == true { println!("|----------------parse_json--------------|\n{}",parsed.to_string()); @@ -51,6 +63,27 @@ fn print_help() { println!("----------------------------------------"); } +/** + * Get the input from the user + * + * @return String + */ +fn get_input(msg: Option<&str>) -> String { + loop { + println!("{}", msg.unwrap_or("").to_string()); + let mut input = String::new(); + match stdin().read_line(&mut input){ + Ok(_) => { + return input.trim().to_string(); + }, + Err(_) => { + println!("Error reading input"); + return String::from(""); + }, + } + } +} + /** * Main loop * @@ -68,7 +101,6 @@ fn main_loop(barrier: Arc<Barrier>, nodes: &mut Vec<data::Node>, _debug: bool) - let mut input = String::new(); print_help(); - println!("----------------------------------------"); println!("nodes :"); for node in nodes.iter() { println!(" {}", node.id); @@ -94,6 +126,7 @@ fn main_loop(barrier: Arc<Barrier>, nodes: &mut Vec<data::Node>, _debug: bool) - } }, _ => { + // check if the input is a node command let mut index_to_remove = None; for (index, node) in nodes.iter().enumerate() { let (target, mut command) = input.trim().split_at(input.trim().find(" ").unwrap_or(0)); @@ -101,6 +134,7 @@ fn main_loop(barrier: Arc<Barrier>, nodes: &mut Vec<data::Node>, _debug: bool) - if target == node.id { did_something = true; node.chan.send(Value::String(command.to_string())).unwrap(); + // check if the command is quit, need to remove the node if true if command == "quit" { index_to_remove = Some(index); } @@ -127,7 +161,13 @@ fn main_loop(barrier: Arc<Barrier>, nodes: &mut Vec<data::Node>, _debug: bool) - fn main() -> std::io::Result<()> { - let arg = std::env::args().nth(1).expect(""); + let arg: String; + match std::env::args().nth(1) { + None => { + arg = get_input(Some("Please provide a JSON file as an argument")); + }, + _ => arg = std::env::args().nth(1).unwrap(), + } let arg2 = std::env::args().nth(2); let mut debug = false; if arg2 == Some("-d".to_string()) { @@ -137,7 +177,13 @@ fn main() -> std::io::Result<()> { let mut parsed = serde_json::Value::Null; let mut nodes: Vec<data::Node> = Vec::new(); - parse_json(arg, &mut parsed, debug).unwrap(); // unwrap is used to handle the error + match parse_json(arg, &mut parsed, debug) { + Ok(_) => (), + Err(e) => { + println!("Error parsing JSON file: {}", e.to_string()); + return Ok(()); + } + } let barrier = Arc::new(Barrier::new(parsed.as_object().expect("Expected a JSON object").len() + 1)); // +1 for the main thread @@ -150,8 +196,6 @@ fn main() -> std::io::Result<()> { chan: tx.clone(), }; - // rx.set_nonblocking(true); - if let Value::Array(_array) = value { node::make_node(Arc::clone(&barrier), key.to_string() , value.clone(), rx, debug); // Pass the cloned receiver to make_node } diff --git a/src/node.rs b/src/node.rs index 101dd76..100975f 100644 --- a/src/node.rs +++ b/src/node.rs @@ -46,7 +46,7 @@ pub fn make_node(barrier: Arc<Barrier>, node_id: String , array: Value, rx: std: if debug == true { println!("{} : llm -> {}, {}", node_id, llm_string, llm_string.split(":").collect::<Vec<&str>>()[1]); } - + let _ = if cfg!(target_os = "windows") { Command::new("cmd") .args(&["/C", &format!("docker run -d -v ollama:/root/.ollama -p {}:11434 --name ollama_{} ollama/ollama", llm_string.split(":").collect::<Vec<&str>>()[1], node_id)]) @@ -60,7 +60,23 @@ pub fn make_node(barrier: Arc<Barrier>, node_id: String , array: Value, rx: std: .expect("failed to execute process") }; - println!("{} -> arbitrary wait for the llm to start", node_id); + println!("{} -> arbitrary wait for the container to start", node_id); + thread::sleep(time::Duration::from_secs(5)); + + let _ = if cfg!(target_os = "windows") { + Command::new("cmd") + .args(&["/C", &format!("docker exec --it ollama_{} ollama run mistral:latest", node_id)]) + .output() + .expect("failed to execute process") + } else { + Command::new("sh") + .arg("-c") + .arg(&format!("docker exec --it ollama_{} ollama run mistral:latest", node_id)) + .output() + .expect("failed to execute process") + }; + + println!("{} -> arbitrary wait for the model to start", node_id); thread::sleep(time::Duration::from_secs(5)); let node_socket = UdpSocket::bind(array[0].get("node").unwrap().as_str().unwrap()).expect("Could not bind node address"); -- GitLab