use std::{collections::HashMap, net::{Shutdown, TcpStream}, io::{Read, Write}, fs}; use native_tls::{TlsConnector, TlsStream}; use base64::{engine::general_purpose::URL_SAFE, Engine as _}; use crate::http_client::html_parser::{CSSCache, CSSParser, CSSSelector, HTML, Node, get_inline_styles, tree_to_vec}; use serde_json; pub enum Connection { Plain(TcpStream), Tls(TlsStream), } impl Connection { fn shutdown(&mut self, how: Shutdown) -> std::io::Result<()> { match self { Connection::Plain(s) => s.shutdown(how), Connection::Tls(s) => {let _ = s.shutdown(); Ok(())}, } } } impl Read for Connection { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { match self { Connection::Plain(s) => s.read(buf), Connection::Tls(s) => s.read(buf), } } } impl Write for Connection { fn write(&mut self, buf: &[u8]) -> std::io::Result { match self { Connection::Plain(s) => s.write(buf), Connection::Tls(s) => s.write(buf), } } fn flush(&mut self) -> std::io::Result<()> { match self { Connection::Plain(s) => s.flush(), Connection::Tls(s) => s.flush(), } } } pub fn resolve_url(scheme: &str, host: &str, port: u16, path: &str, url: &str) -> String { let mut new_url = url; if new_url.contains("://") { return new_url.to_string(); } let resolved_path = if !new_url.starts_with("/") { let mut dir = path.rsplitn(2, '/').nth(1).unwrap_or(""); while new_url.starts_with("../") { new_url = new_url.strip_prefix("../").unwrap(); if dir.contains('/') { dir = dir.rsplitn(2, '/').nth(1).unwrap_or(""); } } format!("{}/{}", dir, new_url) } else { new_url.to_string() }; if resolved_path.starts_with("//") { format!("{}:{}", scheme, resolved_path) } else { format!("{}://{}:{}{}", scheme, host, port, resolved_path) } } pub struct HTTPClient { pub scheme: String, pub host: String, pub path: String, pub port: u16, pub request_headers: HashMap, pub response_explanation: Option, pub response_headers: HashMap, pub response_http_version: Option, pub response_status: Option, pub node: Option, pub css_rules: Vec<(CSSSelector, HashMap)>, pub content_response: String, pub view_source: bool, pub redirect_count: u32, pub needs_render: bool, pub tcp_stream: Option } impl HTTPClient { pub fn new() -> HTTPClient { HTTPClient { scheme: String::new(), host: String::new(), path: String::new(), port: 0, request_headers: HashMap::new(), response_explanation: None, response_headers: HashMap::new(), response_http_version: None, response_status: None, node: None, css_rules: Vec::new(), content_response: String::new(), view_source: false, redirect_count: 0, needs_render: false, tcp_stream: None } } pub fn file_request(&mut self, url: &String) { self.content_response = fs::read_to_string(url.split_once("file://").unwrap().1).unwrap(); } pub fn get_request(&mut self, url: &String, headers: HashMap, css: bool) { let mut parsed_url = url.clone(); if parsed_url.starts_with("view-source:") { parsed_url = parsed_url.split_once("view-source:").unwrap().1.to_string(); self.view_source = true; } else { self.view_source = false; } let (scheme_str, parsed_url_parts) = parsed_url.split_once("://").unwrap(); self.scheme = scheme_str.to_string(); if !(parsed_url_parts.contains("/")) { self.host = parsed_url_parts.to_string(); self.path = "/".to_string(); } else { let (host_str, path_str) = parsed_url_parts.split_once("/").unwrap(); self.host = host_str.to_string(); self.path = format!("/{}", path_str.to_string()); } if self.host.contains(":") { let temp_host = self.host.clone(); let (host_str, port_str) = temp_host.split_once(":").unwrap(); self.host = host_str.to_string(); self.port = port_str.parse().unwrap(); } else { if self.scheme == "http" { self.port = 80; } else { self.port = 443; } } self.request_headers = headers; self.response_explanation = None; self.response_headers = HashMap::new(); self.response_http_version = None; self.response_status = None; self.content_response = "".to_string(); self.tcp_stream = None; if self.request_headers.contains_key("Host") { self.request_headers.remove("Host"); } self.request_headers.insert("Host".to_string(), self.host.clone()); let html_cache_key = URL_SAFE.encode(format!("{}_{}_{}_{}", self.scheme, self.host, self.port, self.path).as_bytes()); let html_cache_path = format!("html_cache/{}.html", html_cache_key); if std::fs::exists(html_cache_path.clone()).unwrap() { self.content_response = fs::read_to_string(html_cache_path).unwrap(); self.parse(); return; } let tcp = TcpStream::connect(format!("{}:{}", self.host, self.port.to_string())).unwrap(); if self.scheme == "https" { let connector = TlsConnector::new().unwrap(); self.tcp_stream = Some(Connection::Tls(connector.connect(self.host.as_str(), tcp).unwrap())); } else { self.tcp_stream = Some(Connection::Plain(tcp)); } let request_header_lines: String = self.request_headers .iter() .map(|(header_name, header_value)|{ format!("{}: {}", header_name, header_value) }) .collect::>() .join("\r\n"); let request = format!("GET {} HTTP/1.0\r\n{}\r\n\r\n", self.path, request_header_lines); self.tcp_stream.as_mut().unwrap().write_all(request.as_bytes()).unwrap(); self.receive_response(css); // TODO: use threading } fn receive_response(&mut self, css: bool) { let mut temp_buffer = [0; 16384]; let mut headers_parsed: bool = false; let mut content_length: Option = None; loop { let bytes_read = self.tcp_stream.as_mut().unwrap().read(&mut temp_buffer).unwrap_or(0); if bytes_read == 0 { println!("Connection closed by peer."); break; } if !headers_parsed { let header_end_index = temp_buffer[..bytes_read].windows(4).position(|window| {window == b"\r\n\r\n"}); if let Some(header_end_index) = header_end_index { let header_data = std::str::from_utf8(&temp_buffer[..header_end_index]).unwrap_or(""); let body_data = &temp_buffer[header_end_index + 4..bytes_read]; // +4 for the \r\n\r\n self._parse_headers(header_data.to_string()); headers_parsed = true; let content_length_header = self.response_headers.get("content-length"); if let Some(content_length_header) = content_length_header { content_length = Some(content_length_header.parse().unwrap()); } self.content_response = std::str::from_utf8(&body_data).unwrap_or("").to_string(); // Assuming body is UTF-8 if !content_length.is_none() && body_data.len() >= content_length.unwrap() { break; } else if content_length.is_none() {} } else { continue; } } else { self.content_response.push_str(std::str::from_utf8(&temp_buffer[..bytes_read]).unwrap_or("")); if !content_length.is_none() && self.content_response.len() >= content_length.unwrap() { break; } } }; if let Some(ref mut stream) = self.tcp_stream { stream.shutdown(Shutdown::Both).ok(); } self.tcp_stream = None; if 300 <= self.response_status.unwrap() && self.response_status.unwrap() < 400 { if self.redirect_count >= 4 { return; } self.redirect_count += 1; let headers = self.request_headers.clone(); let location = self.response_headers.get("location") .cloned() .unwrap_or("/".to_string()); if location.starts_with("http") || location.starts_with("https") { self.get_request(&location, headers, false); } else { self.get_request(&format!("{}://{}{}", self.scheme, self.host, location), headers, false); } } else { self.redirect_count = 0; } if !css { if !(300..400).contains(&self.response_status.unwrap_or(0)) { self.parse(); } } } fn _parse_headers(&mut self, header_data: String) { let lines: Vec<&str> = header_data.lines().collect(); if lines.is_empty() { println!("Received empty header data."); return } let response_status_line = lines[0]; let mut parts = response_status_line.splitn(3, ' '); self.response_http_version = Some(parts.next().unwrap().to_string()); self.response_status = Some(parts.next().unwrap().parse().unwrap()); let explanation_parts: Vec<&str> = parts.collect(); self.response_explanation = Some(explanation_parts.join(" ")); let mut headers = HashMap::new(); for i in 1..lines.len() { let line = &lines[i]; if line.is_empty() { break; } let (header_name, value) = line.split_once(":").unwrap(); headers.insert(header_name.trim().to_lowercase().to_string(), value.trim().to_string()); } self.response_headers = headers; } pub fn parse(&mut self) { self.css_rules.clear(); let html_cache_key = URL_SAFE.encode(format!("{}_{}_{}_{}", self.scheme, self.host, self.port, self.path).as_bytes()); let html_cache_path = format!("html_cache/{}.html", html_cache_key); if std::fs::exists(html_cache_path.clone()).unwrap() { self.content_response = std::fs::read_to_string(html_cache_path).unwrap(); } else { let _ = std::fs::write(html_cache_path, self.content_response.clone()); } let original_scheme = self.scheme.clone(); let original_host = self.host.clone(); let original_port = self.port; let original_path = self.path.clone(); let original_response = self.content_response.clone(); self.node = Some(Node::Element(HTML::new(self.content_response.clone()).parse())); let mut flattened_tree = vec![]; tree_to_vec(self.node.as_ref().unwrap(), &mut flattened_tree); let css_links: Vec = flattened_tree.iter() .filter(|node| { matches!(node, Node::Element(_)) && node.tag().unwrap() == "link".to_string() && node.attributes().unwrap().get("rel").unwrap() == &"stylesheet".to_string() && node.attributes().unwrap().get("href").is_some() }) .map(|node: &&Node| { node.attributes().unwrap()["href"].clone() }).collect(); for css_link in css_links { self.content_response.clear(); // we need to include the other variables so for example /styles.css wouldnt be cached for all websites let css_cache_key = URL_SAFE.encode(format!("{}_{}_{}_{}", self.scheme, self.host, self.port, css_link).as_bytes()); let css_cache_path = format!("css_cache/{}.json", css_cache_key); let rules: Vec<(CSSSelector, HashMap)> = if std::path::Path::new(&css_cache_path).exists() { let css_cache_content = std::fs::read_to_string(&css_cache_path).unwrap(); let json: CSSCache = serde_json::from_str(&css_cache_content).unwrap(); json.css_cache } else { let resolved = resolve_url(self.scheme.as_str(), self.host.as_str(), self.port, self.path.as_str(), css_link.as_str()); let headers = self.request_headers.clone(); self.get_request(&resolved, headers, true); let parsed_css = CSSParser::new(self.content_response.clone()).parse(); let json = CSSCache { css_cache: parsed_css }; let _ = std::fs::write(&css_cache_path, serde_json::to_string(&json).unwrap()); json.css_cache }; self.css_rules.extend(rules); } self.css_rules.extend(get_inline_styles(self.node.as_ref().unwrap())); self.scheme = original_scheme; self.host = original_host; self.port = original_port; self.path = original_path; self.content_response = original_response; self.needs_render = true; } }