Skip to content
Snippets Groups Projects
node.rs 13.2 KiB
Newer Older
use ollama_rs::generation::completion::request::GenerationRequest;
use std::process::Command;
use serde_json::Value;
use std::net::{
    SocketAddr, 
    UdpSocket
};
use std::{
    thread, 
    time
};
use ollama_rs::*;
use tokio::runtime::Runtime;
use chrono::Utc;
use std::sync::{
    Barrier, 
    Arc
};

CARDILE VINCENT's avatar
CARDILE VINCENT committed
use crate::data::{
    self
};
/// Name of the Ollama model passed to every `GenerationRequest` built in `node_loop`.
/// NOTE(review): this is used even when `flags.model` / `flags.model_name` selects a
/// different model in `make_node` — confirm that is intended.
const MODEL: &str = "mymist:latest";
/// Host URL of each node's Ollama API; the port is taken from the node's `llm` entry.
const LOCALHOST: &str = "http://localhost";

/**
 * Make a node
 * 
 * @param barrier: Arc<Barrier> barrier to sync the threads
 * @param node_id: String node id
 * @param array: Value array of the node
CARDILE VINCENT's avatar
CARDILE VINCENT committed
 * @param chanel: data::Com channel to comunicate with the node
CARDILE VINCENT's avatar
CARDILE VINCENT committed
pub fn make_node(barrier: Arc<Barrier>, node_id: String , array: Value, chanel: data::Com, flags: data::Flags) {

    if flags.debug == true {
        println!("|----------------make_node: {}----------------|", node_id);
    }

    let _child = thread::spawn(move || {
        let llm_string = array[1].get("llm").unwrap().as_str().unwrap();

        if flags.debug == true {
            println!("{} : llm -> {}, {}", node_id, llm_string, llm_string.split(":").collect::<Vec<&str>>()[1]);
        }
        
        let _ =  if cfg!(target_os = "windows") {
            Command::new("cmd")
                .args(&["/C", &format!("docker run -d -v ollama:/root/.ollama -p {}:11434 --name ollama_{} ollama/ollama", llm_string.split(":").collect::<Vec<&str>>()[1], node_id)])
                .output()
                .expect("failed to execute process")
        } else {
            Command::new("sh")
                    .arg("-c")
                    .arg(&format!("docker run -d -v ollama:/root/.ollama -p {}:11434 --name ollama_{} ollama/ollama", llm_string.split(":").collect::<Vec<&str>>()[1], node_id))
                    .output()
                    .expect("failed to execute process")
        };

        if flags.bypass ==  false {
            println!("{} -> arbitrary wait for the container to start", node_id);
            thread::sleep(time::Duration::from_secs(5));    
        }

        let mut _model = String::new();
        if flags.model == true {
            _model = flags.model_name.clone();

            let _ =  if cfg!(target_os = "windows") {
                Command::new("cmd")
                    .args(&["/C", &format!("docker exec --it ollama_{} ollama run {}", node_id, _model)])
                    .output()
                    .expect("failed to execute process")
            } else {
                Command::new("sh")
                        .arg("-c")
                        .arg(&format!("docker exec --it ollama_{} ollama run {}", node_id, _model))
                        .output()
                        .expect("failed to execute process")
            };
            Command::new("sh")
                .arg("-c")
                .arg(&format!("docker cp {} ollama_{}:/root/.ollama/{}", _model, node_id, _model))
                .arg(&format!("docker exec ollama_{} ollama create {} --file /root/.ollama/{}", node_id, _model, _model))
                .expect("failed to make model available to the container");

            if flags.bypass ==  false {
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                    println!("{} -> arbitrary wait for the container to init model", node_id);
                    thread::sleep(time::Duration::from_secs(5));    
                .arg("-c")
                .arg(&format!("docker exec ollama_{} ollama run {}", node_id, _model))
                .output()
                .expect("failed to execute process");
        }
        if flags.bypass ==  false {
            println!("{} -> arbitrary wait for the model to start", node_id);
            thread::sleep(time::Duration::from_secs(5));    
        }

        let node_socket = UdpSocket::bind(array[0].get("node").unwrap().as_str().unwrap()).expect("Could not bind node address");
CARDILE VINCENT's avatar
CARDILE VINCENT committed
        let node_port = node_socket.local_addr().unwrap().port();
        let mut comm_vec: Vec<SocketAddr> = Vec::new();
        
        let llm = Ollama::new(LOCALHOST.to_string(), llm_string.split(":").collect::<Vec<&str>>()[1].parse::<u16>().unwrap());

        node_socket.set_nonblocking(true).expect("Could not set non-blocking mode to node socket");

        if let Some(Value::Array(comm_arr)) = array[2].get("neighbours") {
            for comm in comm_arr {
                if let Some(comm_str) = comm.as_str() {
                    if let Ok(comm_addr) = comm_str.parse::<std::net::SocketAddr>() {
                        comm_vec.push(comm_addr);
                    } else {
                        println!("{} : Invalid comm address: {}", node_id, comm_str);
                    }
                } else {
                    println!("{} : Invalid comm address: {:?}", node_id, comm);
                }
            }
        }

        if flags.debug == true {
            println!("{} : node -> {}", node_id, node_socket.local_addr().unwrap());
            for comm in array[2].get("neighbours").iter() {
                println!("{} : comm -> {}", node_id, comm);
            }
        }

        barrier.wait();

        // node_loop(barrier, node_id, rx, node_socket, llm, array[2].get("neighbours"), debug).unwrap();
CARDILE VINCENT's avatar
CARDILE VINCENT committed
        node_loop(barrier, node_id, chanel, node_socket, llm, comm_vec, flags.debug).unwrap();

    });
}

/*
* Node loop: repeatedly polls the master channel and the node's UDP socket, forwards
* received requests to the LLM, and routes the LLM response either to the master or
* to the neighbour whose port matches the response's "destination" field.
*
* @param barrier: Arc<Barrier> barrier to sync the threads
* @param node_id: String node id
* @param chanel: data::Com channel to communicate with the node
* @param node_socket: std::net::UdpSocket node socket
* @param llm: Ollama Ollama instance
* @param comm_vec: Vec<SocketAddr> vector of neighbours
* @param debug: bool debug flag
*
* @return std::io::Result<()>
*/
CARDILE VINCENT's avatar
CARDILE VINCENT committed
fn node_loop( barrier: Arc<Barrier>, node_id: String, chanel: data::Com, node_socket: std::net::UdpSocket, llm: Ollama,comm_vec: Vec<SocketAddr>, debug: bool) -> std::io::Result<()> {
CARDILE VINCENT's avatar
CARDILE VINCENT committed
    let mut msg_master = false;
CARDILE VINCENT's avatar
CARDILE VINCENT committed
    // let mut data: data::Message = data::Message { message: "".to_string() };
    // let mut response: data::Message = data::Message { message: "".to_string() };
    let mut data: String = "".to_string();
    let mut response: String = "".to_string();
    let start_time = Utc::now().timestamp();
    let rt = Runtime::new().unwrap();

    loop {
        // listen to master
        let elapsed = Utc::now().timestamp() - start_time;
CARDILE VINCENT's avatar
CARDILE VINCENT committed
        match chanel.chan_r.try_recv() {
            Ok(Value::String(msg)) => {
                if debug == true {
                    println!("{} -> {}", node_id, msg);
                }

                let (first, rest) = match msg.split_whitespace().next() {
                    Some(first) => (first, msg.trim_start_matches(first).trim_start()),
                    None => continue,
                };

                if debug == true {
                    println!("{} (first, rest) -> {} {}", node_id, first, rest);
                }

                match first.trim().replace(" ", "").as_str() {
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                    "send" => {
                        msg_received = true;
                        msg_master = true;
                        data = format!("{}{}{}",rest.to_string(), ". You are ", node_id);
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                        if debug == true {
                            println!("{} : send received -> {}", node_id, data);
                        }
                    "quit" => {
                        drop(node_socket);
                        if cfg!(target_os = "windows") {
                            match Command::new("cmd")
                                .args(&["/C", &format!("docker rm -f /ollama_{}", node_id)])
                                .output() {
                                    Ok(_) => {},
                                    Err(e) => println!("{} : Error: {}", node_id, e)
                                }
                        } else {
                            match Command::new("sh")
                                .arg("-c")
                                .arg(&format!("docker rm -f /ollama_{}", node_id))
                                .output() {
                                    Ok(_) => {},
                                    Err(e) => println!("{} : Error: {}", node_id, e)
                                }
                        };

                        println!("child exited");
                        
                        barrier.wait();

                        return Ok(());
                    },
                    _ => {
                        // println!("{} : Invalid command: {}", node_id, first);
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                        data= msg.to_string();
                    }
                    
                }
            },
            Ok(Value::Null) => {
                continue;
            },
            Ok(Value::Bool(_)) => {
                println!("Bool value not supported yet!");
            },
            Ok(Value::Number(_)) => {
                println!("Number value not supported yet!");
            },
            Ok(Value::Array(_)) => {
                println!("Array value not supported yet!");
            },
            Ok(Value::Object(_)) => {
                println!("Object value not supported yet!");
            },
            Err(std::sync::mpsc::TryRecvError::Empty) => {
            },
            Err(e) => println!("Error node_loop: {}", e)
        }

        match node_socket.recv_from(&mut buf) {
            Ok(_) => {
                let json_end = buf.iter().position(|&b| b == b'}');
                if let Some(end) = json_end {
                    let json_slice = &buf[..=end];
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                    data = std::str::from_utf8(json_slice).unwrap().to_string();
                    
                    msg_received = true;
                } else {
                    println!("Invalid JSON data");
                }
            },
            Err(e) => {
                if e.kind() != std::io::ErrorKind::WouldBlock {
                    println!("{} : Error: {}", node_id, e);
                }
            }
        }
        
        rt.block_on(async {
            //listen to neighbours
            if msg_received == true {
                if debug == true {
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                    println!("{:?}:{} request -> {}", elapsed, node_id, data);
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                
                response = llm.generate(GenerationRequest::new(MODEL.to_string(), data.clone())).await.clone().unwrap().response;
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                    println!("{:?}:{} response -> {}", elapsed, node_id, response);
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                let mut value: serde_json::Value = serde_json::Value::Null;
                let mut port_d: u32 = 1;

                match serde_json::from_str(&response) {
                    Ok(v) => {
                        value = v;
                        port_d = value["destination"].as_str().unwrap().parse::<u32>().unwrap();
                        if debug == true {
                            println!("{:?}:{} destination -> {}", elapsed, node_id, port_d);
CARDILE VINCENT's avatar
CARDILE VINCENT committed
                    },
                    Err(_) => {
                        println!("{} : Error: parsing response from LLM {}", node_id, response);
                    }
                }

                if port_d == 1 {
                    println!("{:?}:{} Sending to master", elapsed, node_id);
                    chanel.chan_s.send(serde_json::Value::String(response.to_string())).unwrap();
                } else {    
                    for comm in comm_vec.iter() {
                        let comm_str = comm.to_string();
                        let addr = comm_str.trim_matches(|c| c == '[' || c == ']' || c == '\"');
                        let port = addr.split(":").collect::<Vec<&str>>()[1].parse::<u32>().unwrap();

                        if port == port_d {
                            println!("{:?}:{} sending to -> {}", elapsed, node_id, addr);

                            value["origin"] = serde_json::Value::String(node_socket.local_addr().unwrap().to_string().split(":").collect::<Vec<&str>>()[1].to_string());
                            if msg_master == true {
                                value["identity_master"] = serde_json::Value::String(node_id.to_string());
                            }

                            value["identity"] = serde_json::Value::String(node_id.clone());

                            response = value.to_string();
                            let send = serde_json::to_string(&response).unwrap();

                            match node_socket.send_to(send.as_bytes(), addr){
                                Ok(_) => {
                                    if debug == true {
                                        println!("{:?}:{} sent -> {}", elapsed, node_id, send);
                                    }
                                },
                                Err(e) => {
                                    println!("{} : Error: {}", node_id, e);
                                }
                            }
                        }