1use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
58use std::ops::RangeInclusive;
59use std::str::FromStr;
60use std::sync::{Arc, RwLock};
61
62use ipnet::{Ipv4Net, Ipv6Net};
63use iprange::IpRange;
64
65#[derive(Debug, thiserror::Error)]
67pub enum RuleParseError {
68 #[error("invalid connection direction: {0}")]
69 Direction(String),
70 #[error("failed to parse int: {0}")]
71 InvalidInteger(#[from] std::num::ParseIntError),
72 #[error("failed to parse IP range address: {0}")]
73 InvalidIpRange(#[from] ipnet::AddrParseError),
74 #[error("failed to parse IP address: {0}")]
75 InvalidIpAddr(#[from] std::net::AddrParseError),
76 #[error("missing colon in rule: {0}")]
77 MissingColon(String),
78 #[error("Single IPV6 entry is not enclosed in brackets: {0}")]
79 MalformedIpv6(String),
80 #[error("Invalid rule type: {0}. Rule type must be either dns, ipv4, or ipv6")]
81 InvalidRuleType(String),
82 #[error("Invalid rule action: {0}. Rule action must be either allow or deny")]
83 InvalidRuleAction(String),
84 #[error("Domain rule not found for: {0}")]
85 DomainRuleNotFound(String),
86 #[error("Domain rule already expanded: {0}")]
87 DomainAlreadyExpanded(String),
88}
89
/// Traffic direction a rule applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Written as the `/in` suffix in rule strings.
    Inbound,
    /// Written as the `/out` suffix in rule strings.
    Outbound,
    /// Default for rules that carry no direction suffix; matches both directions.
    Bidirectional,
}

impl Direction {
    /// Returns `true` when a rule with direction `self` applies to traffic flowing in
    /// `direction`.
    ///
    /// A `Bidirectional` rule applies to any query; every other rule applies only to the
    /// identical direction (so an `Inbound` rule does not match a `Bidirectional` query).
    pub fn matches(&self, direction: Direction) -> bool {
        match self {
            Direction::Bidirectional => true,
            own => *own == direction,
        }
    }
}
103
104impl FromStr for Direction {
105 type Err = RuleParseError;
106
107 fn from_str(s: &str) -> Result<Self, Self::Err> {
108 let direction = if s == "in" {
109 Direction::Inbound
110 } else if s == "out" {
111 Direction::Outbound
112 } else {
113 return Err(RuleParseError::Direction(s.to_string()));
114 };
115
116 Ok(direction)
117 }
118}
119
/// Set of ports a rule applies to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PortSpec {
    /// Every port (`*`).
    All,
    /// Exactly one port.
    Port(u16),
    /// An inclusive range of ports (`start-end`).
    PortRange(RangeInclusive<u16>),
}

impl PortSpec {
    /// Returns `true` if `port` is covered by this specification.
    pub fn matches(&self, port: u16) -> bool {
        match self {
            PortSpec::Port(single) => port == *single,
            PortSpec::PortRange(range) => range.contains(&port),
            PortSpec::All => true,
        }
    }
}
140
141impl FromStr for PortSpec {
142 type Err = RuleParseError;
143
144 fn from_str(s: &str) -> Result<Self, Self::Err> {
145 let rule = if s == "*" {
146 PortSpec::All
147 } else if s.contains('-') {
148 let (start, end) = s.split_once('-').unwrap();
149
150 let (start, end) = (start.parse()?, end.parse()?);
151
152 PortSpec::PortRange(start..=end)
153 } else {
154 PortSpec::Port(s.parse()?)
155 };
156
157 Ok(rule)
158 }
159}
160
/// Set of domain names a DNS rule applies to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DomainSpec {
    /// Every domain (`*`).
    All,
    /// Exactly one domain name.
    Domain(String),
    /// Any domain ending with the stored suffix. It is parsed from `*<suffix>`, so
    /// `*.example.com` stores `.example.com`.
    DomainGlob(String),
}

impl DomainSpec {
    /// Returns `true` if `domain` is covered by this specification.
    ///
    /// Comparison ignores ASCII case because DNS names are case-insensitive (RFC 4343).
    /// `EXAMPLE.com` must hit the same allow *and deny* rules as `example.com`;
    /// otherwise a deny rule could be bypassed just by changing letter case.
    pub fn matches(&self, domain: impl AsRef<str>) -> bool {
        let domain = domain.as_ref();

        match self {
            DomainSpec::All => true,
            DomainSpec::Domain(allowed_domain) => allowed_domain.eq_ignore_ascii_case(domain),
            DomainSpec::DomainGlob(suffix) => {
                // Compare raw bytes so the slice below can never split a UTF-8 character.
                let (name, tail) = (domain.as_bytes(), suffix.as_bytes());
                name.len() >= tail.len()
                    && name[name.len() - tail.len()..].eq_ignore_ascii_case(tail)
            }
        }
    }
}
183
184impl FromStr for DomainSpec {
185 type Err = RuleParseError;
186
187 fn from_str(s: &str) -> Result<Self, Self::Err> {
188 let spec = if s == "*" {
189 DomainSpec::All
190 } else if let Some(glob) = s.strip_prefix('*') {
191 DomainSpec::DomainGlob(glob.to_string())
192 } else {
193 DomainSpec::Domain(s.to_string())
194 };
195
196 Ok(spec)
197 }
198}
199
200#[derive(Debug, Clone, PartialEq, Eq)]
202pub struct DNSRule {
203 domain: DomainSpec,
205 port: PortSpec,
207 expanded: bool,
210}
211
212impl DNSRule {
213 pub fn allows(&self, domain: impl AsRef<str>) -> bool {
215 self.domain.matches(domain)
216 }
217
218 pub fn allowed_ports(&self) -> PortSpec {
220 self.port.clone()
221 }
222}
223
224#[derive(Debug, Clone, PartialEq, Eq)]
226pub enum IPV4Spec {
227 All,
229 IP(Ipv4Addr),
231 IPRange(IpRange<Ipv4Net>),
233}
234
235impl IPV4Spec {
236 pub fn matches(&self, ip: impl Into<Ipv4Addr>) -> bool {
237 let ip = ip.into();
238
239 match self {
240 IPV4Spec::All => true,
241 IPV4Spec::IP(allowed_ip) => *allowed_ip == ip,
242 IPV4Spec::IPRange(allowed_ip_range) => allowed_ip_range.contains(&ip),
243 }
244 }
245}
246
247impl FromStr for IPV4Spec {
248 type Err = RuleParseError;
249
250 fn from_str(s: &str) -> Result<Self, Self::Err> {
251 let spec = if s == "*" {
252 IPV4Spec::All
253 } else if s.contains('/') {
254 let ip = Ipv4Net::from_str(s)?;
255 let mut ip_range = IpRange::<Ipv4Net>::new();
256 ip_range.add(ip);
257
258 IPV4Spec::IPRange(ip_range)
259 } else {
260 IPV4Spec::IP(s.parse()?)
261 };
262
263 Ok(spec)
264 }
265}
266
267#[derive(Debug, Clone, PartialEq, Eq)]
269pub struct IPV4Rule {
270 ip_spec: IPV4Spec,
272 port_spec: PortSpec,
274 direction: Direction,
276}
277
278impl IPV4Rule {
279 pub fn is_allowed(&self, ip: impl Into<Ipv4Addr>, port: u16, dir: Direction) -> bool {
280 let ip = ip.into();
281
282 self.ip_spec.matches(ip) && self.port_spec.matches(port) && self.direction.matches(dir)
283 }
284}
285
286#[derive(Debug, Clone, PartialEq, Eq)]
288pub enum IPV6Spec {
289 All,
291 IP(Ipv6Addr),
293 IPRange(IpRange<Ipv6Net>),
295}
296
297impl IPV6Spec {
298 pub fn matches(&self, ip: Ipv6Addr) -> bool {
299 match self {
300 IPV6Spec::All => true,
301 IPV6Spec::IP(allowed_ip) => *allowed_ip == ip,
302 IPV6Spec::IPRange(allowed_ip_range) => allowed_ip_range.contains(&ip),
303 }
304 }
305}
306
307impl FromStr for IPV6Spec {
308 type Err = RuleParseError;
309
310 fn from_str(s: &str) -> Result<Self, Self::Err> {
311 let spec = if s == "*" {
312 IPV6Spec::All
313 } else if s.contains('/') {
314 let ip = Ipv6Net::from_str(s)?;
315 let mut ip_range = IpRange::<Ipv6Net>::new();
316 ip_range.add(ip);
317
318 IPV6Spec::IPRange(ip_range)
319 } else {
320 IPV6Spec::IP(s.parse()?)
321 };
322
323 Ok(spec)
324 }
325}
326
327#[derive(Debug, Clone, PartialEq, Eq)]
329pub struct IPV6Rule {
330 ip_spec: IPV6Spec,
332 port_spec: PortSpec,
334 direction: Direction,
336}
337
338impl IPV6Rule {
339 pub fn is_allowed(&self, ip: impl Into<Ipv6Addr>, port: u16, dir: Direction) -> bool {
340 let ip = ip.into();
341
342 self.ip_spec.matches(ip) && self.port_spec.matches(port) && self.direction.matches(dir)
343 }
344}
345
346#[derive(Debug, Clone, PartialEq, Eq)]
348pub enum Rule {
349 IPV4(IPV4Rule),
351 IPV6(IPV6Rule),
353 DNS(DNSRule),
355 Neg(Arc<Rule>),
357}
358
359impl Rule {
360 pub fn allows_socket(&self, socket_addr: SocketAddr, direction: Direction) -> bool {
362 let ip = socket_addr.ip();
363 let port = socket_addr.port();
364
365 match (self, ip) {
366 (Rule::IPV4(rule), IpAddr::V4(ip)) => rule.is_allowed(ip, port, direction),
367 (Rule::IPV6(rule), IpAddr::V6(ip)) => rule.is_allowed(ip, port, direction),
368 _ => false,
369 }
370 }
371
372 pub fn allows_domain(&self, domain: impl AsRef<str>) -> bool {
374 if let Rule::DNS(rule) = self {
375 rule.allows(domain)
376 } else {
377 false
378 }
379 }
380
381 pub fn blocks_socket(&self, socket_addr: SocketAddr, direction: Direction) -> bool {
383 if let Rule::Neg(rule) = self {
384 rule.allows_socket(socket_addr, direction)
385 } else {
386 false
387 }
388 }
389
390 pub fn blocks_domain(&self, domain: impl AsRef<str>) -> bool {
392 if let Rule::Neg(rule) = self {
393 rule.allows_domain(domain)
394 } else {
395 false
396 }
397 }
398
399 pub fn port_spec_of_domain(&mut self, domain: impl AsRef<str>) -> Option<PortSpec> {
401 if let Rule::DNS(rule) = self {
402 if rule.allows(domain) {
403 return Some(rule.allowed_ports());
404 }
405 }
406
407 None
408 }
409
410 pub fn is_expandable(&self) -> bool {
412 if let Rule::DNS(rule) = self {
413 !rule.expanded
414 } else {
415 false
416 }
417 }
418
419 pub fn set_expanded(&mut self, expanded: bool) {
421 if let Rule::DNS(rule) = self {
422 rule.expanded = expanded;
423 }
424 }
425}
426
/// Returns the text between `left` and `right` when `s` is wrapped in exactly that pair
/// of delimiters, e.g. `parse_enclosed("{a,b}", '{', '}') == Some("a,b")`.
///
/// Returns `None` when `s` does not start with `left` or does not end with `right`.
/// Text outside the delimiters (such as `x{a}` or `{a}y`) is rejected rather than
/// silently dropped, so a typo in a rule string surfaces as a parse error.
fn parse_enclosed(s: &str, left: char, right: char) -> Option<&str> {
    s.strip_prefix(left)?.strip_suffix(right)
}
435
436fn parse_as_list<T: FromStr<Err = RuleParseError>>(s: &str) -> Result<Vec<T>, RuleParseError> {
437 let entries = if let Some(entries) = parse_enclosed(s, '{', '}') {
438 entries
439 .split(',')
440 .map(|s| s.trim().parse())
441 .collect::<Result<Vec<_>, _>>()?
442 } else {
443 let entry = T::from_str(s)?;
444
445 vec![entry]
446 };
447
448 Ok(entries)
449}
450
451fn parse_ipv4_rule(s: &str) -> Result<Vec<IPV4Rule>, RuleParseError> {
452 let (ips, ports_and_direction) = s
453 .split_once(':')
454 .ok_or_else(|| RuleParseError::MissingColon(s.to_string()))?;
455
456 let mut direction = Direction::Bidirectional;
457 let ports = if let Some((ports, dir)) = ports_and_direction.split_once('/') {
458 direction = dir.parse()?;
459
460 ports
461 } else {
462 ports_and_direction
463 };
464
465 let mut rules = Vec::new();
466 let ips = parse_as_list::<IPV4Spec>(ips)?;
467 let ports = parse_as_list::<PortSpec>(ports)?;
468
469 for ip in &ips {
470 for port in &ports {
471 rules.push(IPV4Rule {
472 ip_spec: ip.clone(),
473 port_spec: port.clone(),
474 direction,
475 });
476 }
477 }
478
479 Ok(rules)
480}
481
482fn parse_ipv6_rule(s: &str) -> Result<Vec<IPV6Rule>, RuleParseError> {
483 let (ips, ports_and_direction) = s
484 .rsplit_once(':')
485 .ok_or_else(|| RuleParseError::MissingColon(s.to_string()))?;
486
487 let mut direction = Direction::Bidirectional;
488 let ports = if let Some((ports, dir)) = ports_and_direction.split_once('/') {
489 direction = dir.parse()?;
490
491 ports
492 } else {
493 ports_and_direction
494 };
495
496 let mut rules = Vec::new();
497
498 let ips = if ips.contains('[') {
499 let ip = parse_enclosed(ips, '[', ']')
500 .ok_or_else(|| RuleParseError::MalformedIpv6(ips.to_string()))?;
501
502 vec![ip.parse::<IPV6Spec>()?]
503 } else {
504 parse_as_list::<IPV6Spec>(ips)?
505 };
506 let ports = parse_as_list::<PortSpec>(ports)?;
507
508 for ip in &ips {
509 for port in &ports {
510 rules.push(IPV6Rule {
511 ip_spec: ip.clone(),
512 port_spec: port.clone(),
513 direction,
514 });
515 }
516 }
517
518 Ok(rules)
519}
520
521fn parse_dns_rule(s: &str) -> Result<Vec<DNSRule>, RuleParseError> {
522 let (domains, ports) = s
523 .split_once(':')
524 .ok_or_else(|| RuleParseError::MissingColon(s.to_string()))?;
525
526 let mut rules = Vec::new();
527 let domains = parse_as_list::<DomainSpec>(domains)?;
528 let ports = parse_as_list::<PortSpec>(ports)?;
529
530 for domain in &domains {
531 for port in &ports {
532 rules.push(DNSRule {
533 domain: domain.clone(),
534 port: port.clone(),
535 expanded: false,
536 });
537 }
538 }
539
540 Ok(rules)
541}
542
543#[derive(Debug, Clone, PartialEq, Eq)]
545enum RuleType {
546 Dns,
547 IPV4,
548 IPV6,
549}
550
551impl RuleType {
552 pub fn consume_input(input: &str) -> Result<(Self, &str), RuleParseError> {
557 let pair = if let Some(rem) = input.strip_prefix("dns:") {
558 (RuleType::Dns, rem)
559 } else if let Some(rem) = input.strip_prefix("ipv4:") {
560 (RuleType::IPV4, rem)
561 } else if let Some(rem) = input.strip_prefix("ipv6:") {
562 (RuleType::IPV6, rem)
563 } else {
564 return Err(RuleParseError::InvalidRuleType(input.to_string()));
565 };
566
567 Ok(pair)
568 }
569}
570
571#[derive(Debug, Clone, PartialEq, Eq)]
573enum RuleAction {
574 Allow,
575 Deny,
576}
577
578impl RuleAction {
579 pub fn consume_input(input: &str) -> Result<(Self, &str), RuleParseError> {
585 let pair = if let Some(rem) = input.strip_prefix("allow=") {
586 (RuleAction::Allow, rem)
587 } else if let Some(rem) = input.strip_prefix("deny=") {
588 (RuleAction::Deny, rem)
589 } else {
590 return Err(RuleParseError::InvalidRuleAction(input.to_string()));
591 };
592
593 Ok(pair)
594 }
595}
596
597#[derive(Debug, Clone, PartialEq, Eq)]
599struct RuleExpr(String);
600
601impl RuleExpr {
602 pub fn consume_input(input: &str) -> Result<(Self, &str), RuleParseError> {
609 let mut next_dns_entry = usize::MAX;
610 let mut next_ipv4_entry = usize::MAX;
611 let mut next_ipv6_entry = usize::MAX;
612
613 if let Some(idx) = input.find(",dns:") {
614 next_dns_entry = idx;
615 }
616
617 if let Some(idx) = input.find(",ipv4:") {
618 next_ipv4_entry = idx;
619 }
620
621 if let Some(idx) = input.find(",ipv6:") {
622 next_ipv6_entry = idx;
623 }
624
625 let next_entry = next_dns_entry
626 .min(next_ipv4_entry)
627 .min(next_ipv6_entry)
628 .min(input.len());
629
630 let (expr, rem) = input.split_at(next_entry);
631
632 let rem = rem.strip_prefix(',').unwrap_or(rem);
633
634 Ok((RuleExpr(expr.to_string()), rem))
635 }
636}
637
638#[derive(Debug, Clone, PartialEq, Eq)]
643struct RulesetSegment {
644 ty: RuleType,
645 action: RuleAction,
646 expr: RuleExpr,
647}
648
649fn parse_ruleset_segments(s: impl AsRef<str>) -> Result<Vec<RulesetSegment>, RuleParseError> {
650 let mut input = s.as_ref();
651 let mut segments = Vec::new();
652
653 while !input.is_empty() {
654 let (ty, remaining) = RuleType::consume_input(input)?;
655 let (action, remaining) = RuleAction::consume_input(remaining)?;
656 let (expr, remaining) = RuleExpr::consume_input(remaining)?;
657
658 segments.push(RulesetSegment { ty, action, expr });
659
660 input = remaining;
661 }
662
663 Ok(segments)
664}
665
666#[derive(Debug, Clone)]
669pub struct Ruleset {
670 rules: Arc<RwLock<Vec<Rule>>>,
671}
672
673impl Ruleset {
674 pub fn allows_socket(&self, addr: impl Into<SocketAddr>, dir: Direction) -> bool {
677 let addr = addr.into();
678
679 {
680 let ruleset = self.rules.read().unwrap();
681
682 let is_blacklisted = ruleset.iter().any(|r| r.blocks_socket(addr, dir));
683 if is_blacklisted {
684 return false;
685 }
686
687 ruleset.iter().any(|r| r.allows_socket(addr, dir))
688 }
689 }
690
691 pub fn allows_domain(&self, domain: impl AsRef<str>) -> bool {
693 let domain = domain.as_ref();
694
695 {
696 let ruleset = self.rules.read().unwrap();
697
698 let is_blacklisted = ruleset.iter().any(|r| r.blocks_domain(domain));
699 if is_blacklisted {
700 return false;
701 }
702
703 ruleset.iter().any(|r| r.allows_domain(domain))
704 }
705 }
706
707 pub fn expand_domain(
710 &self,
711 domain: impl AsRef<str>,
712 addrs: impl AsRef<[IpAddr]>,
713 ) -> Result<(), RuleParseError> {
714 let mut ruleset = self.rules.write().unwrap();
715 let domain = domain.as_ref();
716
717 let mut already_expanded = false;
718 let port_spec = ruleset
719 .iter_mut()
720 .find_map(|rule| {
721 let port_spec = rule.port_spec_of_domain(domain);
722
723 if port_spec.is_some() {
724 if rule.is_expandable() {
725 rule.set_expanded(true);
726
727 return port_spec;
728 } else {
729 already_expanded = true;
730 }
731 }
732
733 None
734 })
735 .ok_or_else(|| {
736 if already_expanded {
737 RuleParseError::DomainAlreadyExpanded(domain.to_string())
738 } else {
739 RuleParseError::DomainRuleNotFound(domain.to_string())
740 }
741 })?;
742
743 for addr in addrs.as_ref() {
744 let rule = match addr {
745 IpAddr::V4(ip) => Rule::IPV4(IPV4Rule {
746 ip_spec: IPV4Spec::IP(*ip),
747 port_spec: port_spec.clone(),
748 direction: Direction::Outbound,
749 }),
750 IpAddr::V6(ip) => Rule::IPV6(IPV6Rule {
751 ip_spec: IPV6Spec::IP(*ip),
752 port_spec: port_spec.clone(),
753 direction: Direction::Outbound,
754 }),
755 };
756
757 ruleset.push(rule);
758 }
759
760 Ok(())
761 }
762}
763
764impl FromStr for Ruleset {
765 type Err = RuleParseError;
766
767 fn from_str(s: &str) -> Result<Self, Self::Err> {
768 let s: String = s.chars().filter(|c| !c.is_whitespace()).collect();
769 let mut rules = vec![];
770 for seg in parse_ruleset_segments(s)? {
771 let rule_type = &seg.ty;
772 let rule_action = &seg.action;
773 let rule_expr = &seg.expr;
774
775 let parsed_rules: Vec<Rule> = match rule_type {
776 RuleType::Dns => parse_dns_rule(&rule_expr.0)?
777 .into_iter()
778 .map(Rule::DNS)
779 .collect(),
780 RuleType::IPV4 => parse_ipv4_rule(&rule_expr.0)?
781 .into_iter()
782 .map(Rule::IPV4)
783 .collect(),
784 RuleType::IPV6 => parse_ipv6_rule(&rule_expr.0)?
785 .into_iter()
786 .map(Rule::IPV6)
787 .collect(),
788 };
789
790 let parsed_rules = match rule_action {
791 RuleAction::Allow => parsed_rules,
792 RuleAction::Deny => parsed_rules
793 .into_iter()
794 .map(|rule| Rule::Neg(Arc::new(rule)))
795 .collect(),
796 };
797
798 rules.extend(parsed_rules);
799 }
800
801 Ok(Self {
802 rules: Arc::new(RwLock::new(rules)),
803 })
804 }
805}
806
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IPv6 literal used in a test.
    fn v6(s: &str) -> Ipv6Addr {
        s.parse().unwrap()
    }

    // ---- PortSpec ----

    #[test]
    fn all_ports_spec() {
        let spec = PortSpec::from_str("*").unwrap();

        assert!(spec.matches(80));
    }

    #[test]
    fn port_spec() {
        let spec = PortSpec::from_str("80").unwrap();

        assert!(spec.matches(80));
        assert!(!spec.matches(443));
    }

    #[test]
    fn port_range_spec() {
        let spec = PortSpec::from_str("80-85").unwrap();

        assert!(!spec.matches(79));
        for port in 80..=85 {
            assert!(spec.matches(port));
        }
        assert!(!spec.matches(86));
    }

    // ---- DomainSpec ----

    #[test]
    fn all_domains_spec() {
        let spec = DomainSpec::from_str("*").unwrap();

        assert!(spec.matches("example.com"));
    }

    #[test]
    fn domain_spec() {
        let spec = DomainSpec::from_str("example.com").unwrap();

        assert!(spec.matches("example.com"));
        assert!(!spec.matches("sub.example.com"));
        assert!(!spec.matches("test.com"));
    }

    #[test]
    fn domain_glob_spec() {
        let spec = DomainSpec::from_str("*.example.com").unwrap();

        assert!(!spec.matches("example.com"));
        assert!(spec.matches("sub.example.com"));
        assert!(spec.matches("another.sub.example.com"));
        assert!(!spec.matches("test.com"));
    }

    // ---- IPv4 / IPv6 specs ----

    #[test]
    fn all_ipv4s_spec() {
        let spec = IPV4Spec::from_str("*").unwrap();

        assert!(spec.matches([127, 0, 0, 1]));
    }

    #[test]
    fn ipv4_spec() {
        let spec = IPV4Spec::from_str("127.0.0.1").unwrap();

        assert!(spec.matches([127, 0, 0, 1]));
        assert!(!spec.matches([192, 168, 1, 1]));
    }

    #[test]
    fn ipv4_range_spec() {
        let rule = IPV4Spec::from_str("192.168.1.0/24").unwrap();

        let matches = [
            "192.168.1.1",
            "192.168.1.0",
            "192.168.1.255",
            "192.168.1.100",
            "192.168.1.50",
        ];
        let non_matches = [
            "192.168.2.0",
            "192.167.1.1",
            "10.0.0.1",
            "172.16.0.1",
            "192.168.0.255",
        ];

        for ip in &matches {
            let ip_addr: Ipv4Addr = ip.parse().unwrap();
            assert!(rule.matches(ip_addr));
        }
        for ip in &non_matches {
            let ip_addr: Ipv4Addr = ip.parse().unwrap();
            assert!(!rule.matches(ip_addr));
        }
    }

    #[test]
    fn all_ipv6s_spec() {
        let spec = IPV6Spec::from_str("*").unwrap();

        assert!(spec.matches(v6("2001:db8::1")));
    }

    #[test]
    fn ipv6_spec() {
        let spec = IPV6Spec::from_str("2001:db8::1").unwrap();

        assert!(spec.matches(v6("2001:db8::1")));
        assert!(!spec.matches(v6("2001:db7::1")));
    }

    #[test]
    fn ipv6_range_spec() {
        let spec = IPV6Spec::from_str("2001:db8::/32").unwrap();

        let matches = [
            "2001:db8::1",
            "2001:db8::",
            "2001:db8:0:0:0:0:0:1234",
            "2001:db8::abcd",
            "2001:db8::ffff",
        ];
        let non_matches = ["2001:db9::", "2001:db7::1", "2001:dead::1", "fe80::1", "::1"];

        for ip in &matches {
            assert!(spec.matches(v6(ip)));
        }
        for ip in &non_matches {
            assert!(!spec.matches(v6(ip)));
        }
    }

    // ---- DNS rule parser ----

    #[test]
    fn dns_rule_all() {
        let rules = parse_dns_rule("*:*").unwrap();

        assert_eq!(rules.len(), 1);
        assert!(rules[0].allows("example.com"));
        assert_eq!(rules[0].allowed_ports(), PortSpec::All);
    }

    #[test]
    fn dns_rule_single_domain_and_port() {
        let rules = parse_dns_rule("example.com:80").unwrap();

        assert_eq!(rules.len(), 1);
        assert!(rules[0].allows("example.com"));
        assert_eq!(rules[0].allowed_ports(), PortSpec::Port(80));
    }

    #[test]
    fn dns_rule_multiple_domain_and_ports() {
        let mut rules = parse_dns_rule("{a.com, *.b.com}:{80, 100-200}").unwrap();

        // Rules are generated domain-major, so popping yields them in reverse order.
        let rule1 = rules.pop().unwrap();
        let rule2 = rules.pop().unwrap();
        let rule3 = rules.pop().unwrap();
        let rule4 = rules.pop().unwrap();
        assert!(rules.is_empty());

        assert!(rule1.allows("sub.b.com"));
        assert!(!rule1.allows("b.com"));
        assert!(!rule1.allows("a.com"));
        assert_eq!(rule1.allowed_ports(), PortSpec::PortRange(100..=200));

        assert!(rule2.allows("sub.b.com"));
        assert!(!rule2.allows("b.com"));
        assert!(!rule2.allows("a.com"));
        assert_eq!(rule2.allowed_ports(), PortSpec::Port(80));

        assert!(rule3.allows("a.com"));
        assert!(!rule3.allows("sub.a.com"));
        assert!(!rule3.allows("b.com"));
        assert_eq!(rule3.allowed_ports(), PortSpec::PortRange(100..=200));

        assert!(rule4.allows("a.com"));
        assert!(!rule4.allows("sub.a.com"));
        assert!(!rule4.allows("b.com"));
        assert_eq!(rule4.allowed_ports(), PortSpec::Port(80));
    }

    // ---- IPv4 rule parser ----

    #[test]
    fn ipv4_rule_all() {
        let rules = parse_ipv4_rule("*:*").unwrap();

        assert_eq!(rules.len(), 1);
        assert!(rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
        assert!(rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Outbound));
    }

    #[test]
    fn ipv4_rule_single_ip_all_ports_inbound() {
        let rules = parse_ipv4_rule("127.0.0.1:*/in").unwrap();

        assert_eq!(rules.len(), 1);
        assert!(rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
        assert!(!rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Outbound));
        assert!(!rules[0].is_allowed([192, 168, 1, 2], 80, Direction::Inbound));
        assert!(!rules[0].is_allowed([192, 168, 1, 2], 80, Direction::Outbound));
    }

    #[test]
    fn ipv4_rule_ip_range_all_ports_outbound() {
        let mut rules = parse_ipv4_rule("192.168.1.0/24:*/out").unwrap();

        let ip_matches = [
            "192.168.1.1",
            "192.168.1.0",
            "192.168.1.255",
            "192.168.1.100",
            "192.168.1.50",
        ];
        let ip_non_matches = [
            "192.168.2.0",
            "192.167.1.1",
            "10.0.0.1",
            "172.16.0.1",
            "192.168.0.255",
        ];

        assert_eq!(rules.len(), 1);
        let rule = rules.pop().unwrap();

        for ip in &ip_matches {
            let ip_addr: Ipv4Addr = ip.parse().unwrap();
            assert!(rule.is_allowed(ip_addr, 8080, Direction::Outbound));
            assert!(!rule.is_allowed(ip_addr, 8080, Direction::Inbound));
        }
        // Use the rule's own direction so the rejection is due to the address alone.
        for ip in &ip_non_matches {
            let ip_addr: Ipv4Addr = ip.parse().unwrap();
            assert!(!rule.is_allowed(ip_addr, 8080, Direction::Outbound));
        }
    }

    #[test]
    fn ipv4_rule_all_ip_port_range_outbound() {
        let rules = parse_ipv4_rule("*:80-90/out").unwrap();

        assert_eq!(rules.len(), 1);
        assert!(!rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
        assert!(rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Outbound));
        assert!(rules[0].is_allowed([127, 0, 0, 1], 85, Direction::Outbound));
        assert!(rules[0].is_allowed([127, 0, 0, 1], 90, Direction::Outbound));
        assert!(!rules[0].is_allowed([127, 0, 0, 1], 443, Direction::Outbound));
        assert!(!rules[0].is_allowed([192, 168, 1, 2], 80, Direction::Inbound));
        assert!(rules[0].is_allowed([192, 168, 1, 2], 80, Direction::Outbound));
    }

    #[test]
    fn multiple_ipv4_rules() {
        let mut rules = parse_ipv4_rule("{127.0.0.1, 192.168.1.0/24}:{80, 8080}/in").unwrap();

        // Address-major generation; popping yields (range, 8080), (range, 80),
        // (127.0.0.1, 8080), (127.0.0.1, 80).
        let rule1 = rules.pop().unwrap();
        let rule2 = rules.pop().unwrap();
        let rule3 = rules.pop().unwrap();
        let rule4 = rules.pop().unwrap();
        assert!(rules.is_empty());

        let ip_matches = [
            "192.168.1.1",
            "192.168.1.0",
            "192.168.1.255",
            "192.168.1.100",
            "192.168.1.50",
        ];
        let ip_non_matches = [
            "192.168.2.0",
            "192.167.1.1",
            "10.0.0.1",
            "172.16.0.1",
            "192.168.0.255",
        ];

        for ip in &ip_matches {
            let ip_addr: Ipv4Addr = ip.parse().unwrap();

            assert!(rule1.is_allowed(ip_addr, 8080, Direction::Inbound));
            assert!(!rule1.is_allowed(ip_addr, 8080, Direction::Outbound));
            assert!(!rule1.is_allowed(ip_addr, 80, Direction::Inbound));

            assert!(rule2.is_allowed(ip_addr, 80, Direction::Inbound));
            assert!(!rule2.is_allowed(ip_addr, 80, Direction::Outbound));
            assert!(!rule2.is_allowed(ip_addr, 8080, Direction::Inbound));
        }
        for ip in &ip_non_matches {
            let ip_addr: Ipv4Addr = ip.parse().unwrap();

            assert!(!rule1.is_allowed(ip_addr, 8080, Direction::Inbound));
            assert!(!rule2.is_allowed(ip_addr, 80, Direction::Inbound));
        }

        assert!(rule3.is_allowed([127, 0, 0, 1], 8080, Direction::Inbound));
        assert!(!rule3.is_allowed([192, 168, 1, 2], 8080, Direction::Inbound));
        assert!(!rule3.is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
        assert!(!rule3.is_allowed([127, 0, 0, 1], 8080, Direction::Outbound));

        assert!(rule4.is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
        assert!(!rule4.is_allowed([192, 168, 1, 2], 80, Direction::Inbound));
        assert!(!rule4.is_allowed([127, 0, 0, 1], 8080, Direction::Inbound));
        assert!(!rule4.is_allowed([127, 0, 0, 1], 80, Direction::Outbound));
    }

    // ---- IPv6 rule parser ----

    #[test]
    fn ipv6_rule_all() {
        let rules = parse_ipv6_rule("*:*").unwrap();

        assert_eq!(rules.len(), 1);
        assert!(rules[0].is_allowed(v6("2001:db8::1"), 80, Direction::Inbound));
        assert!(rules[0].is_allowed(v6("2001:db8::1"), 80, Direction::Outbound));
    }

    #[test]
    fn ipv6_rule_single_ip_and_port() {
        let rules = parse_ipv6_rule("[2001:db8::1]:80").unwrap();

        assert_eq!(rules.len(), 1);
        assert!(rules[0].is_allowed(v6("2001:db8::1"), 80, Direction::Inbound));
        assert!(rules[0].is_allowed(v6("2001:db8::1"), 80, Direction::Outbound));
    }

    #[test]
    fn ipv6_rule_single_ip_all_ports_inbound() {
        let rules = parse_ipv6_rule("[2001:db8::1]:*/in").unwrap();

        assert_eq!(rules.len(), 1);
        assert!(rules[0].is_allowed(v6("2001:db8::1"), 80, Direction::Inbound));
        assert!(!rules[0].is_allowed(v6("2002:db8::1"), 80, Direction::Inbound));
        assert!(!rules[0].is_allowed(v6("2001:db8::1"), 8080, Direction::Outbound));
    }

    #[test]
    fn ipv6_rule_ip_range_all_ports_outbound() {
        let mut rules = parse_ipv6_rule("[2001:db8::/32]:*/out").unwrap();

        let ip_matches = [
            "2001:db8::1",
            "2001:db8::",
            "2001:db8:0:0:0:0:0:1234",
            "2001:db8::abcd",
            "2001:db8::ffff",
        ];
        let ip_non_matches = ["2001:db9::", "2001:db7::1", "2001:dead::1", "fe80::1", "::1"];

        assert_eq!(rules.len(), 1);
        let rule = rules.pop().unwrap();

        for ip in &ip_matches {
            assert!(rule.is_allowed(v6(ip), 8080, Direction::Outbound));
            assert!(!rule.is_allowed(v6(ip), 8080, Direction::Inbound));
        }
        // Use the rule's own direction so the rejection is due to the address alone.
        for ip in &ip_non_matches {
            assert!(!rule.is_allowed(v6(ip), 8080, Direction::Outbound));
        }
    }

    #[test]
    fn multiple_ipv6_rules() {
        let mut rules = parse_ipv6_rule("{3001:db8::, 2001:db8::/32}:{80, 8080}/in").unwrap();

        // Address-major generation; popping yields (range, 8080), (range, 80),
        // (3001:db8::, 8080), (3001:db8::, 80).
        let rule1 = rules.pop().unwrap();
        let rule2 = rules.pop().unwrap();
        let rule3 = rules.pop().unwrap();
        let rule4 = rules.pop().unwrap();
        assert!(rules.is_empty());

        let ip_matches = [
            "2001:db8::1",
            "2001:db8::",
            "2001:db8:0:0:0:0:0:1234",
            "2001:db8::abcd",
            "2001:db8::ffff",
        ];
        let ip_non_matches = ["2001:db9::", "2001:db7::1", "2001:dead::1", "fe80::1", "::1"];

        for ip in &ip_matches {
            let ip_addr = v6(ip);

            assert!(rule1.is_allowed(ip_addr, 8080, Direction::Inbound));
            assert!(!rule1.is_allowed(ip_addr, 8080, Direction::Outbound));
            assert!(!rule1.is_allowed(ip_addr, 80, Direction::Inbound));

            assert!(rule2.is_allowed(ip_addr, 80, Direction::Inbound));
            assert!(!rule2.is_allowed(ip_addr, 80, Direction::Outbound));
            assert!(!rule2.is_allowed(ip_addr, 8080, Direction::Inbound));
        }
        for ip in &ip_non_matches {
            let ip_addr = v6(ip);

            assert!(!rule1.is_allowed(ip_addr, 8080, Direction::Inbound));
            assert!(!rule2.is_allowed(ip_addr, 80, Direction::Inbound));
        }

        assert!(rule3.is_allowed(v6("3001:db8::"), 8080, Direction::Inbound));
        assert!(!rule3.is_allowed(v6("4001:db8::"), 8080, Direction::Inbound));
        assert!(!rule3.is_allowed(v6("3001:db8::"), 80, Direction::Inbound));
        assert!(!rule3.is_allowed(v6("3001:db8::"), 8080, Direction::Outbound));

        assert!(rule4.is_allowed(v6("3001:db8::"), 80, Direction::Inbound));
        assert!(!rule4.is_allowed(v6("4001:db8::"), 80, Direction::Inbound));
        assert!(!rule4.is_allowed(v6("3001:db8::"), 8080, Direction::Inbound));
        assert!(!rule4.is_allowed(v6("3001:db8::"), 80, Direction::Outbound));
    }

    // ---- Ruleset ----

    #[test]
    fn ruleset_dns() {
        let ruleset = Ruleset::from_str("dns:allow={a.com, *.b.com}:{80, 8080}").unwrap();

        assert!(ruleset.allows_domain("a.com"));
        assert!(!ruleset.allows_domain("sub.a.com"));
        assert!(!ruleset.allows_domain("b.com"));
        assert!(ruleset.allows_domain("sub.b.com"));
        assert!(ruleset.allows_domain("another.sub.b.com"));
    }

    #[test]
    fn ruleset_ipv4() {
        // Only deny rules: nothing is allowed at all.
        let ruleset =
            Ruleset::from_str("ipv4:deny={127.0.0.1, 192.168.1.0/24}:{80, 8080}/in").unwrap();

        let ip_matches = [
            "192.168.1.1",
            "192.168.1.0",
            "192.168.1.255",
            "192.168.1.100",
            "192.168.1.50",
        ];

        for ip in &ip_matches {
            let ip_addr: Ipv4Addr = ip.parse().unwrap();
            assert!(!ruleset.allows_socket((ip_addr, 8080), Direction::Inbound));
        }

        assert!(!ruleset.allows_socket(([127, 0, 0, 1], 8080), Direction::Inbound));
        assert!(!ruleset.allows_socket(([127, 0, 0, 1], 80), Direction::Inbound));
    }

    #[test]
    fn ruleset_deny_overrides_allow() {
        let ruleset = Ruleset::from_str(
            "ipv4:allow=*:*, ipv4:deny={127.0.0.1, 192.168.1.0/24}:{80, 8080}/in",
        )
        .unwrap();

        // Covered by the deny rule: refused despite the allow-all rule.
        assert!(!ruleset.allows_socket(([127, 0, 0, 1], 80), Direction::Inbound));
        assert!(!ruleset.allows_socket(([192, 168, 1, 7], 8080), Direction::Inbound));

        // Outside the deny rule (port, direction or address): allowed by `*:*`.
        assert!(ruleset.allows_socket(([127, 0, 0, 1], 443), Direction::Inbound));
        assert!(ruleset.allows_socket(([192, 168, 1, 7], 8080), Direction::Outbound));
        assert!(ruleset.allows_socket(([10, 0, 0, 1], 8080), Direction::Inbound));
    }

    #[test]
    fn ruleset_ipv6() {
        let ruleset =
            Ruleset::from_str("ipv6:allow={3001:db8::, 2001:db8::/32}:{80, 8080}/in").unwrap();

        let ip_matches = [
            "2001:db8::1",
            "2001:db8::",
            "2001:db8:0:0:0:0:0:1234",
            "2001:db8::abcd",
            "2001:db8::ffff",
        ];

        for ip in &ip_matches {
            assert!(ruleset.allows_socket((v6(ip), 8080), Direction::Inbound));
        }

        assert!(ruleset.allows_socket((v6("3001:db8::"), 8080), Direction::Inbound));
        assert!(!ruleset.allows_socket((v6("3001:db8::"), 8080), Direction::Outbound));
    }

    #[test]
    fn ruleset_full() {
        let ruleset = Ruleset::from_str(
            "dns:allow={a.com, *.b.com}:{80, 8080},
             ipv4:deny={127.0.0.1, 192.168.1.0/24}:{80, 8080}/in,
             ipv6:allow={3001:db8::, 2001:db8::/32}:{80, 8080}/in",
        )
        .unwrap();

        assert!(ruleset.allows_domain("a.com"));
        assert!(!ruleset.allows_domain("sub.a.com"));
        assert!(!ruleset.allows_domain("b.com"));
        assert!(ruleset.allows_domain("sub.b.com"));
        assert!(ruleset.allows_domain("another.sub.b.com"));

        let ipv4_matches = [
            "192.168.1.1",
            "192.168.1.0",
            "192.168.1.255",
            "192.168.1.100",
            "192.168.1.50",
        ];

        for ip in &ipv4_matches {
            let ip_addr: Ipv4Addr = ip.parse().unwrap();
            assert!(!ruleset.allows_socket((ip_addr, 8080), Direction::Inbound));
        }

        assert!(!ruleset.allows_socket(([127, 0, 0, 1], 8080), Direction::Inbound));
        assert!(!ruleset.allows_socket(([127, 0, 0, 1], 80), Direction::Inbound));

        let ipv6_matches = [
            "2001:db8::1",
            "2001:db8::",
            "2001:db8:0:0:0:0:0:1234",
            "2001:db8::abcd",
            "2001:db8::ffff",
        ];

        for ip in &ipv6_matches {
            assert!(ruleset.allows_socket((v6(ip), 8080), Direction::Inbound));
        }

        assert!(ruleset.allows_socket((v6("3001:db8::"), 8080), Direction::Inbound));
        assert!(!ruleset.allows_socket((v6("3001:db8::"), 8080), Direction::Outbound));
    }

    #[test]
    fn invalid_rulesets_are_rejected() {
        assert!(matches!(
            Ruleset::from_str("tcp:allow=*:*"),
            Err(RuleParseError::InvalidRuleType(_))
        ));
        assert!(matches!(
            Ruleset::from_str("ipv4:permit=*:*"),
            Err(RuleParseError::InvalidRuleAction(_))
        ));
        assert!(matches!(
            Ruleset::from_str("ipv4:allow=127.0.0.1"),
            Err(RuleParseError::MissingColon(_))
        ));
        assert!(matches!(
            Ruleset::from_str("ipv4:allow=127.0.0.1:80/both"),
            Err(RuleParseError::Direction(_))
        ));
        assert!(matches!(
            Ruleset::from_str("ipv4:allow=127.0.0.1:99999"),
            Err(RuleParseError::InvalidInteger(_))
        ));
        assert!(matches!(
            Ruleset::from_str("ipv6:allow=[2001:db8::1:80"),
            Err(RuleParseError::MalformedIpv6(_))
        ));
    }
}
1489}