virtual_net/
ruleset.rs

/// A [`Ruleset`] can be used to specify a whitelist and a blacklist in order to
/// control the inbound and outbound traffic of a network.
///
/// ## Rule Specification
/// Each rule is expressed as:
/// ```text
/// <rule_kind>:<rule_action>=<rule_expr>
///
/// <rule_kind>: dns, ipv4, ipv6
///
/// <rule_action>: allow | deny
///
/// dns:
/// <rule_expr>:
/// {<domain_spec>}:{<port_spec>} (this will be expanded to an outbound IP rule)
/// <domain_spec>: domain | domain glob | *
///
/// ipv4:
/// <rule_expr>:
/// <ipv4_specs>:<port_specs>/<in|out>
/// <ipv4_specs>: <ipv4_spec> | {<ipv4_spec>,}
/// <ipv4_spec>: ipv4 | ipv4_range | *
///
/// ipv6:
/// <rule_expr>:
/// <ipv6_specs>:<port_specs>/<in|out>
/// <ipv6_specs>: <ipv6_spec> | {<ipv6_spec>,}
/// <ipv6_spec>: ipv6 | ipv6_range | *
///
/// <port_specs>: <port_spec> | {<port_specs>,}
/// <port_spec>: port | start_port-end_port | *
/// ```
///
/// The current implementation supports:
///
/// ### Whitelisting and Blacklisting
/// Each rule can be expressed as an `allow` (whitelist) or `deny` (blacklist). A socket or domain
/// is only accessible if at least one rule whitelists it and no rule blacklists it.
///
/// ### Directional Filtering
/// IP based rules can either be directional, by appending an `/in` or `/out` suffix to the rule,
/// or bidirectional, which is the default when no suffix is given.
///
/// ### Rule Combination
/// In order to prevent repetition, the parts before and after the `:` can each hold multiple
/// values enclosed in braces.
/// For example:
/// ```text
/// ipv4:deny={127.0.0.1/24, 192.168.1.1/24}:{80, 443}
/// ```
/// This is equivalent to:
/// ```text
/// ipv4:deny=127.0.0.1/24:80,
/// ipv4:deny=127.0.0.1/24:443,
/// ipv4:deny=192.168.1.1/24:80,
/// ipv4:deny=192.168.1.1/24:443
/// ```
57use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
58use std::ops::RangeInclusive;
59use std::str::FromStr;
60use std::sync::{Arc, RwLock};
61
62use ipnet::{Ipv4Net, Ipv6Net};
63use iprange::IpRange;
64
65/// Represents the errors that could happen during parsing the ruleset
66#[derive(Debug, thiserror::Error)]
67pub enum RuleParseError {
68    #[error("invalid connection direction: {0}")]
69    Direction(String),
70    #[error("failed to parse int: {0}")]
71    InvalidInteger(#[from] std::num::ParseIntError),
72    #[error("failed to parse IP range address: {0}")]
73    InvalidIpRange(#[from] ipnet::AddrParseError),
74    #[error("failed to parse IP address: {0}")]
75    InvalidIpAddr(#[from] std::net::AddrParseError),
76    #[error("missing colon in rule: {0}")]
77    MissingColon(String),
78    #[error("Single IPV6 entry is not enclosed in brackets: {0}")]
79    MalformedIpv6(String),
80    #[error("Invalid rule type: {0}. Rule type must be either dns, ipv4, or ipv6")]
81    InvalidRuleType(String),
82    #[error("Invalid rule action: {0}. Rule action must be either allow or deny")]
83    InvalidRuleAction(String),
84    #[error("Domain rule not found for: {0}")]
85    DomainRuleNotFound(String),
86    #[error("Domain rule already expanded: {0}")]
87    DomainAlreadyExpanded(String),
88}
89
/// Direction of the network traffic a rule applies to
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Inbound traffic only
    Inbound,
    /// Outbound traffic only
    Outbound,
    /// Both inbound and outbound traffic
    Bidirectional,
}

impl Direction {
    /// Returns `true` if a rule with this direction covers traffic flowing in `direction`.
    ///
    /// A `Bidirectional` rule covers every direction; any other rule only covers
    /// the direction it is equal to.
    pub fn matches(&self, direction: Direction) -> bool {
        match self {
            Direction::Bidirectional => true,
            own => *own == direction,
        }
    }
}
103
104impl FromStr for Direction {
105    type Err = RuleParseError;
106
107    fn from_str(s: &str) -> Result<Self, Self::Err> {
108        let direction = if s == "in" {
109            Direction::Inbound
110        } else if s == "out" {
111            Direction::Outbound
112        } else {
113            return Err(RuleParseError::Direction(s.to_string()));
114        };
115
116        Ok(direction)
117    }
118}
119
/// Describes which ports a rule applies to
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PortSpec {
    /// Every port
    All,
    /// Exactly one port
    Port(u16),
    /// An inclusive range of ports
    PortRange(RangeInclusive<u16>),
}

impl PortSpec {
    /// Returns `true` if `port` is covered by this specification
    pub fn matches(&self, port: u16) -> bool {
        match self {
            PortSpec::Port(single) => port == *single,
            PortSpec::PortRange(range) => range.contains(&port),
            PortSpec::All => true,
        }
    }
}
140
141impl FromStr for PortSpec {
142    type Err = RuleParseError;
143
144    fn from_str(s: &str) -> Result<Self, Self::Err> {
145        let rule = if s == "*" {
146            PortSpec::All
147        } else if s.contains('-') {
148            let (start, end) = s.split_once('-').unwrap();
149
150            let (start, end) = (start.parse()?, end.parse()?);
151
152            PortSpec::PortRange(start..=end)
153        } else {
154            PortSpec::Port(s.parse()?)
155        };
156
157        Ok(rule)
158    }
159}
160
/// Specification of a domain
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DomainSpec {
    /// All domains
    All,
    /// A single domain like: example.com
    Domain(String),
    /// A domain glob like: *.example.com
    ///
    /// The stored string is the suffix that a domain has to end with.
    DomainGlob(String),
}

impl DomainSpec {
    /// Returns `true` if `domain` is covered by this specification.
    ///
    /// DNS names are case-insensitive, so the comparison ignores ASCII case. Without this,
    /// a rule written for `example.com` would not apply to a query for `EXAMPLE.com`.
    pub fn matches(&self, domain: impl AsRef<str>) -> bool {
        let domain = domain.as_ref();

        match self {
            DomainSpec::All => true,
            DomainSpec::Domain(allowed_domain) => allowed_domain.eq_ignore_ascii_case(domain),
            DomainSpec::DomainGlob(domain_glob) => {
                // Compare on bytes so that slicing can never land inside a multi-byte character.
                let name = domain.as_bytes();
                let suffix = domain_glob.as_bytes();

                name.len() >= suffix.len()
                    && name[name.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
            }
        }
    }
}
183
184impl FromStr for DomainSpec {
185    type Err = RuleParseError;
186
187    fn from_str(s: &str) -> Result<Self, Self::Err> {
188        let spec = if s == "*" {
189            DomainSpec::All
190        } else if let Some(glob) = s.strip_prefix('*') {
191            DomainSpec::DomainGlob(glob.to_string())
192        } else {
193            DomainSpec::Domain(s.to_string())
194        };
195
196        Ok(spec)
197    }
198}
199
200/// Represents a DNS rule
201#[derive(Debug, Clone, PartialEq, Eq)]
202pub struct DNSRule {
203    // The allowed domain
204    domain: DomainSpec,
205    // The allowed port
206    port: PortSpec,
207    // Indicates whether this rule has been expanded into
208    // a list of IP and port based rules
209    expanded: bool,
210}
211
212impl DNSRule {
213    /// Returns `true` if the `domain` is allowed by this rule
214    pub fn allows(&self, domain: impl AsRef<str>) -> bool {
215        self.domain.matches(domain)
216    }
217
218    /// Returns the allowed ports on the domains allowed by this rule
219    pub fn allowed_ports(&self) -> PortSpec {
220        self.port.clone()
221    }
222}
223
224/// Specification of an Ipv4
225#[derive(Debug, Clone, PartialEq, Eq)]
226pub enum IPV4Spec {
227    /// All IPs
228    All,
229    /// A single IP
230    IP(Ipv4Addr),
231    /// An IP range in the format of `ip/mask`
232    IPRange(IpRange<Ipv4Net>),
233}
234
235impl IPV4Spec {
236    pub fn matches(&self, ip: impl Into<Ipv4Addr>) -> bool {
237        let ip = ip.into();
238
239        match self {
240            IPV4Spec::All => true,
241            IPV4Spec::IP(allowed_ip) => *allowed_ip == ip,
242            IPV4Spec::IPRange(allowed_ip_range) => allowed_ip_range.contains(&ip),
243        }
244    }
245}
246
247impl FromStr for IPV4Spec {
248    type Err = RuleParseError;
249
250    fn from_str(s: &str) -> Result<Self, Self::Err> {
251        let spec = if s == "*" {
252            IPV4Spec::All
253        } else if s.contains('/') {
254            let ip = Ipv4Net::from_str(s)?;
255            let mut ip_range = IpRange::<Ipv4Net>::new();
256            ip_range.add(ip);
257
258            IPV4Spec::IPRange(ip_range)
259        } else {
260            IPV4Spec::IP(s.parse()?)
261        };
262
263        Ok(spec)
264    }
265}
266
267/// Represents an Ipv4 rule
268#[derive(Debug, Clone, PartialEq, Eq)]
269pub struct IPV4Rule {
270    // Allowed IPs
271    ip_spec: IPV4Spec,
272    // Allowed ports
273    port_spec: PortSpec,
274    // Allowed direction of the traffic
275    direction: Direction,
276}
277
278impl IPV4Rule {
279    pub fn is_allowed(&self, ip: impl Into<Ipv4Addr>, port: u16, dir: Direction) -> bool {
280        let ip = ip.into();
281
282        self.ip_spec.matches(ip) && self.port_spec.matches(port) && self.direction.matches(dir)
283    }
284}
285
286/// Specification of an Ipv6 address
287#[derive(Debug, Clone, PartialEq, Eq)]
288pub enum IPV6Spec {
289    /// All IPs
290    All,
291    /// Single IP
292    IP(Ipv6Addr),
293    /// An IP range in the format of `ip/mask`
294    IPRange(IpRange<Ipv6Net>),
295}
296
297impl IPV6Spec {
298    pub fn matches(&self, ip: Ipv6Addr) -> bool {
299        match self {
300            IPV6Spec::All => true,
301            IPV6Spec::IP(allowed_ip) => *allowed_ip == ip,
302            IPV6Spec::IPRange(allowed_ip_range) => allowed_ip_range.contains(&ip),
303        }
304    }
305}
306
307impl FromStr for IPV6Spec {
308    type Err = RuleParseError;
309
310    fn from_str(s: &str) -> Result<Self, Self::Err> {
311        let spec = if s == "*" {
312            IPV6Spec::All
313        } else if s.contains('/') {
314            let ip = Ipv6Net::from_str(s)?;
315            let mut ip_range = IpRange::<Ipv6Net>::new();
316            ip_range.add(ip);
317
318            IPV6Spec::IPRange(ip_range)
319        } else {
320            IPV6Spec::IP(s.parse()?)
321        };
322
323        Ok(spec)
324    }
325}
326
327/// Represents an Ipv6 rule
328#[derive(Debug, Clone, PartialEq, Eq)]
329pub struct IPV6Rule {
330    // Allowed IPs
331    ip_spec: IPV6Spec,
332    // Allowed ports
333    port_spec: PortSpec,
334    // Allowed direction of the traffic
335    direction: Direction,
336}
337
338impl IPV6Rule {
339    pub fn is_allowed(&self, ip: impl Into<Ipv6Addr>, port: u16, dir: Direction) -> bool {
340        let ip = ip.into();
341
342        self.ip_spec.matches(ip) && self.port_spec.matches(port) && self.direction.matches(dir)
343    }
344}
345
346/// Represents all supported rules
347#[derive(Debug, Clone, PartialEq, Eq)]
348pub enum Rule {
349    /// Allowed IPv4 traffic
350    IPV4(IPV4Rule),
351    /// Allowed IPv6 traffic
352    IPV6(IPV6Rule),
353    /// Allowed DNS queries
354    DNS(DNSRule),
355    /// Negative of a rule
356    Neg(Arc<Rule>),
357}
358
359impl Rule {
360    /// Returns `true` if this rule allows accessing `socket_addr` in the specific `direction`
361    pub fn allows_socket(&self, socket_addr: SocketAddr, direction: Direction) -> bool {
362        let ip = socket_addr.ip();
363        let port = socket_addr.port();
364
365        match (self, ip) {
366            (Rule::IPV4(rule), IpAddr::V4(ip)) => rule.is_allowed(ip, port, direction),
367            (Rule::IPV6(rule), IpAddr::V6(ip)) => rule.is_allowed(ip, port, direction),
368            _ => false,
369        }
370    }
371
372    /// Returns `true` if this rule allows querying the specific `domain`
373    pub fn allows_domain(&self, domain: impl AsRef<str>) -> bool {
374        if let Rule::DNS(rule) = self {
375            rule.allows(domain)
376        } else {
377            false
378        }
379    }
380
381    /// Returns `true` if this rule blocks accessing `socket_addr` in the specific `direction`
382    pub fn blocks_socket(&self, socket_addr: SocketAddr, direction: Direction) -> bool {
383        if let Rule::Neg(rule) = self {
384            rule.allows_socket(socket_addr, direction)
385        } else {
386            false
387        }
388    }
389
390    /// Returns `true` if this rule blocks querying the specific `domain`
391    pub fn blocks_domain(&self, domain: impl AsRef<str>) -> bool {
392        if let Rule::Neg(rule) = self {
393            rule.allows_domain(domain)
394        } else {
395            false
396        }
397    }
398
399    /// Returns allowed ports for the specified `domain` if this rule is a DNS rule
400    pub fn port_spec_of_domain(&mut self, domain: impl AsRef<str>) -> Option<PortSpec> {
401        if let Rule::DNS(rule) = self {
402            if rule.allows(domain) {
403                return Some(rule.allowed_ports());
404            }
405        }
406
407        None
408    }
409
410    /// Returns `true` if this rule is a DNS rule and has not been expanded yet
411    pub fn is_expandable(&self) -> bool {
412        if let Rule::DNS(rule) = self {
413            !rule.expanded
414        } else {
415            false
416        }
417    }
418
419    /// Sets the expanded state of this rule if its a DNS rule
420    pub fn set_expanded(&mut self, expanded: bool) {
421        if let Rule::DNS(rule) = self {
422            rule.expanded = expanded;
423        }
424    }
425}
426
/// Returns the text between the first occurrence of `left` and the last occurrence of
/// `right` in `s`, or `None` if either delimiter is missing or `right` does not come
/// after `left`.
///
/// Text outside the delimiters is ignored.
fn parse_enclosed(s: &str, left: char, right: char) -> Option<&str> {
    let open = s.find(left)?;
    let close = s.rfind(right)?;

    if open < close {
        // Skip the whole opening delimiter, which may be wider than one byte.
        Some(&s[open + left.len_utf8()..close])
    } else {
        None
    }
}
435
436fn parse_as_list<T: FromStr<Err = RuleParseError>>(s: &str) -> Result<Vec<T>, RuleParseError> {
437    let entries = if let Some(entries) = parse_enclosed(s, '{', '}') {
438        entries
439            .split(',')
440            .map(|s| s.trim().parse())
441            .collect::<Result<Vec<_>, _>>()?
442    } else {
443        let entry = T::from_str(s)?;
444
445        vec![entry]
446    };
447
448    Ok(entries)
449}
450
451fn parse_ipv4_rule(s: &str) -> Result<Vec<IPV4Rule>, RuleParseError> {
452    let (ips, ports_and_direction) = s
453        .split_once(':')
454        .ok_or_else(|| RuleParseError::MissingColon(s.to_string()))?;
455
456    let mut direction = Direction::Bidirectional;
457    let ports = if let Some((ports, dir)) = ports_and_direction.split_once('/') {
458        direction = dir.parse()?;
459
460        ports
461    } else {
462        ports_and_direction
463    };
464
465    let mut rules = Vec::new();
466    let ips = parse_as_list::<IPV4Spec>(ips)?;
467    let ports = parse_as_list::<PortSpec>(ports)?;
468
469    for ip in &ips {
470        for port in &ports {
471            rules.push(IPV4Rule {
472                ip_spec: ip.clone(),
473                port_spec: port.clone(),
474                direction,
475            });
476        }
477    }
478
479    Ok(rules)
480}
481
482fn parse_ipv6_rule(s: &str) -> Result<Vec<IPV6Rule>, RuleParseError> {
483    let (ips, ports_and_direction) = s
484        .rsplit_once(':')
485        .ok_or_else(|| RuleParseError::MissingColon(s.to_string()))?;
486
487    let mut direction = Direction::Bidirectional;
488    let ports = if let Some((ports, dir)) = ports_and_direction.split_once('/') {
489        direction = dir.parse()?;
490
491        ports
492    } else {
493        ports_and_direction
494    };
495
496    let mut rules = Vec::new();
497
498    let ips = if ips.contains('[') {
499        let ip = parse_enclosed(ips, '[', ']')
500            .ok_or_else(|| RuleParseError::MalformedIpv6(ips.to_string()))?;
501
502        vec![ip.parse::<IPV6Spec>()?]
503    } else {
504        parse_as_list::<IPV6Spec>(ips)?
505    };
506    let ports = parse_as_list::<PortSpec>(ports)?;
507
508    for ip in &ips {
509        for port in &ports {
510            rules.push(IPV6Rule {
511                ip_spec: ip.clone(),
512                port_spec: port.clone(),
513                direction,
514            });
515        }
516    }
517
518    Ok(rules)
519}
520
521fn parse_dns_rule(s: &str) -> Result<Vec<DNSRule>, RuleParseError> {
522    let (domains, ports) = s
523        .split_once(':')
524        .ok_or_else(|| RuleParseError::MissingColon(s.to_string()))?;
525
526    let mut rules = Vec::new();
527    let domains = parse_as_list::<DomainSpec>(domains)?;
528    let ports = parse_as_list::<PortSpec>(ports)?;
529
530    for domain in &domains {
531        for port in &ports {
532            rules.push(DNSRule {
533                domain: domain.clone(),
534                port: port.clone(),
535                expanded: false,
536            });
537        }
538    }
539
540    Ok(rules)
541}
542
543// Represents the rule type section in a rule segment
544#[derive(Debug, Clone, PartialEq, Eq)]
545enum RuleType {
546    Dns,
547    IPV4,
548    IPV6,
549}
550
551impl RuleType {
552    // Receives a string as input and returns the parsed out rule type and the remaining string
553    // |-------------|---...
554    // rule_type ----^     ^
555    // rem ----------------'
556    pub fn consume_input(input: &str) -> Result<(Self, &str), RuleParseError> {
557        let pair = if let Some(rem) = input.strip_prefix("dns:") {
558            (RuleType::Dns, rem)
559        } else if let Some(rem) = input.strip_prefix("ipv4:") {
560            (RuleType::IPV4, rem)
561        } else if let Some(rem) = input.strip_prefix("ipv6:") {
562            (RuleType::IPV6, rem)
563        } else {
564            return Err(RuleParseError::InvalidRuleType(input.to_string()));
565        };
566
567        Ok(pair)
568    }
569}
570
571// Represents the rule action section in a [`RulesetSegment`]
572#[derive(Debug, Clone, PartialEq, Eq)]
573enum RuleAction {
574    Allow,
575    Deny,
576}
577
578impl RuleAction {
579    // Receives a string as input and returns the parsed out rule action and the remaining string
580    // |----------|---------|---...
581    // rule_type -^         ^     ^
582    // rule_action ---------'     '
583    // rem -----------------------'
584    pub fn consume_input(input: &str) -> Result<(Self, &str), RuleParseError> {
585        let pair = if let Some(rem) = input.strip_prefix("allow=") {
586            (RuleAction::Allow, rem)
587        } else if let Some(rem) = input.strip_prefix("deny=") {
588            (RuleAction::Deny, rem)
589        } else {
590            return Err(RuleParseError::InvalidRuleAction(input.to_string()));
591        };
592
593        Ok(pair)
594    }
595}
596
597// Represents the rule expression section in a [`RulesetSegment`]
598#[derive(Debug, Clone, PartialEq, Eq)]
599struct RuleExpr(String);
600
601impl RuleExpr {
602    // Receives a string as input and returns the parsed out rule expression and the remaining string
603    // |----------|---------|-----|---...
604    // rule_type -^         ^     ^     ^
605    // rule_action ---------'     '     '
606    // rule_expr -----------------'     '
607    // rem -----------------------------'
608    pub fn consume_input(input: &str) -> Result<(Self, &str), RuleParseError> {
609        let mut next_dns_entry = usize::MAX;
610        let mut next_ipv4_entry = usize::MAX;
611        let mut next_ipv6_entry = usize::MAX;
612
613        if let Some(idx) = input.find(",dns:") {
614            next_dns_entry = idx;
615        }
616
617        if let Some(idx) = input.find(",ipv4:") {
618            next_ipv4_entry = idx;
619        }
620
621        if let Some(idx) = input.find(",ipv6:") {
622            next_ipv6_entry = idx;
623        }
624
625        let next_entry = next_dns_entry
626            .min(next_ipv4_entry)
627            .min(next_ipv6_entry)
628            .min(input.len());
629
630        let (expr, rem) = input.split_at(next_entry);
631
632        let rem = rem.strip_prefix(',').unwrap_or(rem);
633
634        Ok((RuleExpr(expr.to_string()), rem))
635    }
636}
637
638// A ruleset is a series of comma separated ruleset segments:
639//     <rule1>, <rule2>, ...
640// each rule is consistent of three sections:
641//     <rule-type>:<rule-action>=<rule-expr>
642#[derive(Debug, Clone, PartialEq, Eq)]
643struct RulesetSegment {
644    ty: RuleType,
645    action: RuleAction,
646    expr: RuleExpr,
647}
648
649fn parse_ruleset_segments(s: impl AsRef<str>) -> Result<Vec<RulesetSegment>, RuleParseError> {
650    let mut input = s.as_ref();
651    let mut segments = Vec::new();
652
653    while !input.is_empty() {
654        let (ty, remaining) = RuleType::consume_input(input)?;
655        let (action, remaining) = RuleAction::consume_input(remaining)?;
656        let (expr, remaining) = RuleExpr::consume_input(remaining)?;
657
658        segments.push(RulesetSegment { ty, action, expr });
659
660        input = remaining;
661    }
662
663    Ok(segments)
664}
665
666/// Represents a ruleset that can be used to specify a whitelist and a blacklist in order to
667/// control the inbound and outbound traffic of a network.
668#[derive(Debug, Clone)]
669pub struct Ruleset {
670    rules: Arc<RwLock<Vec<Rule>>>,
671}
672
673impl Ruleset {
674    /// Returns `true` if at least one rule allows accessing `socket_addr` in the specific `direction`
675    /// and no rule blocks it
676    pub fn allows_socket(&self, addr: impl Into<SocketAddr>, dir: Direction) -> bool {
677        let addr = addr.into();
678
679        {
680            let ruleset = self.rules.read().unwrap();
681
682            let is_blacklisted = ruleset.iter().any(|r| r.blocks_socket(addr, dir));
683            if is_blacklisted {
684                return false;
685            }
686
687            ruleset.iter().any(|r| r.allows_socket(addr, dir))
688        }
689    }
690
691    /// Returns `true` if at least one rule allows querying the specific `domain` and no rule blocks it
692    pub fn allows_domain(&self, domain: impl AsRef<str>) -> bool {
693        let domain = domain.as_ref();
694
695        {
696            let ruleset = self.rules.read().unwrap();
697
698            let is_blacklisted = ruleset.iter().any(|r| r.blocks_domain(domain));
699            if is_blacklisted {
700                return false;
701            }
702
703            ruleset.iter().any(|r| r.allows_domain(domain))
704        }
705    }
706
707    /// Expands the DNS rule that allows the specified `domain` into a list of IP based
708    /// rules with addresses specified by `addrs`
709    pub fn expand_domain(
710        &self,
711        domain: impl AsRef<str>,
712        addrs: impl AsRef<[IpAddr]>,
713    ) -> Result<(), RuleParseError> {
714        let mut ruleset = self.rules.write().unwrap();
715        let domain = domain.as_ref();
716
717        let mut already_expanded = false;
718        let port_spec = ruleset
719            .iter_mut()
720            .find_map(|rule| {
721                let port_spec = rule.port_spec_of_domain(domain);
722
723                if port_spec.is_some() {
724                    if rule.is_expandable() {
725                        rule.set_expanded(true);
726
727                        return port_spec;
728                    } else {
729                        already_expanded = true;
730                    }
731                }
732
733                None
734            })
735            .ok_or_else(|| {
736                if already_expanded {
737                    RuleParseError::DomainAlreadyExpanded(domain.to_string())
738                } else {
739                    RuleParseError::DomainRuleNotFound(domain.to_string())
740                }
741            })?;
742
743        for addr in addrs.as_ref() {
744            let rule = match addr {
745                IpAddr::V4(ip) => Rule::IPV4(IPV4Rule {
746                    ip_spec: IPV4Spec::IP(*ip),
747                    port_spec: port_spec.clone(),
748                    direction: Direction::Outbound,
749                }),
750                IpAddr::V6(ip) => Rule::IPV6(IPV6Rule {
751                    ip_spec: IPV6Spec::IP(*ip),
752                    port_spec: port_spec.clone(),
753                    direction: Direction::Outbound,
754                }),
755            };
756
757            ruleset.push(rule);
758        }
759
760        Ok(())
761    }
762}
763
764impl FromStr for Ruleset {
765    type Err = RuleParseError;
766
767    fn from_str(s: &str) -> Result<Self, Self::Err> {
768        let s: String = s.chars().filter(|c| !c.is_whitespace()).collect();
769        let mut rules = vec![];
770        for seg in parse_ruleset_segments(s)? {
771            let rule_type = &seg.ty;
772            let rule_action = &seg.action;
773            let rule_expr = &seg.expr;
774
775            let parsed_rules: Vec<Rule> = match rule_type {
776                RuleType::Dns => parse_dns_rule(&rule_expr.0)?
777                    .into_iter()
778                    .map(Rule::DNS)
779                    .collect(),
780                RuleType::IPV4 => parse_ipv4_rule(&rule_expr.0)?
781                    .into_iter()
782                    .map(Rule::IPV4)
783                    .collect(),
784                RuleType::IPV6 => parse_ipv6_rule(&rule_expr.0)?
785                    .into_iter()
786                    .map(Rule::IPV6)
787                    .collect(),
788            };
789
790            let parsed_rules = match rule_action {
791                RuleAction::Allow => parsed_rules,
792                RuleAction::Deny => parsed_rules
793                    .into_iter()
794                    .map(|rule| Rule::Neg(Arc::new(rule)))
795                    .collect(),
796            };
797
798            rules.extend(parsed_rules);
799        }
800
801        Ok(Self {
802            rules: Arc::new(RwLock::new(rules)),
803        })
804    }
805}
806
807#[cfg(test)]
808mod tests {
809    use super::*;
810
811    #[test]
812    fn all_ports_spec() {
813        let spec = PortSpec::from_str("*").unwrap();
814
815        assert!(spec.matches(80));
816    }
817
818    #[test]
819    fn port_spec() {
820        let spec = PortSpec::from_str("80").unwrap();
821
822        assert!(spec.matches(80));
823        assert!(!spec.matches(443));
824    }
825
826    #[test]
827    fn port_range_spec() {
828        let spec = PortSpec::from_str("80-85").unwrap();
829
830        assert!(!spec.matches(79));
831        assert!(spec.matches(80));
832        assert!(spec.matches(81));
833        assert!(spec.matches(82));
834        assert!(spec.matches(83));
835        assert!(spec.matches(84));
836        assert!(spec.matches(85));
837        assert!(!spec.matches(86));
838    }
839
840    #[test]
841    fn all_domains_spec() {
842        let spec = DomainSpec::from_str("*").unwrap();
843
844        assert!(spec.matches("example.com"));
845    }
846
847    #[test]
848    fn domain_spec() {
849        let spec = DomainSpec::from_str("example.com").unwrap();
850
851        assert!(spec.matches("example.com"));
852        assert!(!spec.matches("sub.example.com"));
853        assert!(!spec.matches("test.com"));
854    }
855
856    #[test]
857    fn domain_glob_spec() {
858        let spec = DomainSpec::from_str("*.example.com").unwrap();
859
860        assert!(!spec.matches("example.com"));
861        assert!(spec.matches("sub.example.com"));
862        assert!(spec.matches("another.sub.example.com"));
863        assert!(!spec.matches("test.com"));
864    }
865
866    #[test]
867    fn all_ipv4s_spec() {
868        let spec = IPV4Spec::from_str("*").unwrap();
869
870        assert!(spec.matches([127, 0, 0, 1]));
871    }
872
873    #[test]
874    fn ipv4_spec() {
875        let spec = IPV4Spec::from_str("127.0.0.1").unwrap();
876
877        assert!(spec.matches([127, 0, 0, 1]));
878        assert!(!spec.matches([192, 168, 1, 1]));
879    }
880
881    #[test]
882    fn ipv4_range_spec() {
883        let rule = IPV4Spec::from_str("192.168.1.0/24").unwrap();
884
885        let matches = vec![
886            "192.168.1.1",
887            "192.168.1.0",
888            "192.168.1.255",
889            "192.168.1.100",
890            "192.168.1.50",
891        ];
892
893        let non_matches = vec![
894            "192.168.2.0",
895            "192.167.1.1",
896            "10.0.0.1",
897            "172.16.0.1",
898            "192.168.0.255",
899        ];
900
901        for ip in matches {
902            let ip_addr: Ipv4Addr = ip.parse().unwrap();
903            assert!(rule.matches(ip_addr));
904        }
905
906        for ip in non_matches {
907            let ip_addr: Ipv4Addr = ip.parse().unwrap();
908            assert!(!rule.matches(ip_addr));
909        }
910    }
911
912    #[test]
913    fn all_ipv6s_spec() {
914        let spec = IPV6Spec::from_str("*").unwrap();
915
916        assert!(spec.matches("2001:db8::1".parse().unwrap()));
917    }
918
919    #[test]
920    fn ipv6_spec() {
921        let spec = IPV6Spec::from_str("2001:db8::1").unwrap();
922
923        assert!(spec.matches("2001:db8::1".parse().unwrap()));
924        assert!(!spec.matches("2001:db7::1".parse().unwrap()));
925    }
926
927    #[test]
928    fn ipv6_range_spec() {
929        let spec = IPV6Spec::from_str("2001:db8::/32").unwrap();
930
931        let matches = vec![
932            "2001:db8::1",
933            "2001:db8::",
934            "2001:db8:0:0:0:0:0:1234",
935            "2001:db8::abcd",
936            "2001:db8::ffff",
937        ];
938
939        let non_matches = vec![
940            "2001:db9::",
941            "2001:db7::1",
942            "2001:dead::1",
943            "fe80::1",
944            "::1",
945        ];
946
947        for ip in matches {
948            let ip_addr: Ipv6Addr = ip.parse().unwrap();
949            assert!(spec.matches(ip_addr));
950        }
951
952        for ip in non_matches {
953            let ip_addr: Ipv6Addr = ip.parse().unwrap();
954            assert!(!spec.matches(ip_addr));
955        }
956    }
957
958    #[test]
959    fn dns_rule_all() {
960        let rules = parse_dns_rule("*:*").unwrap();
961
962        assert_eq!(rules.len(), 1);
963        assert!(rules[0].allows("example.com"));
964        assert_eq!(rules[0].allowed_ports(), PortSpec::All);
965    }
966
967    #[test]
968    fn dns_rule_single_domain_and_port() {
969        let rules = parse_dns_rule("example.com:80").unwrap();
970
971        assert_eq!(rules.len(), 1);
972        assert!(rules[0].allows("example.com"));
973        assert_eq!(rules[0].allowed_ports(), PortSpec::Port(80));
974    }
975
976    #[test]
977    fn dns_rule_multiple_domain_and_ports() {
978        let mut rules = parse_dns_rule("{a.com, *.b.com}:{80, 100-200}").unwrap();
979
980        let rule1 = rules.pop().unwrap(); // *.b.com:100-200
981        let rule2 = rules.pop().unwrap(); // *.b.com:80
982        let rule3 = rules.pop().unwrap(); // a.com:100-200
983        let rule4 = rules.pop().unwrap(); // a.com:80
984
985        assert!(rules.is_empty());
986
987        assert!(rule1.allows("sub.b.com"));
988        assert!(!rule1.allows("b.com"));
989        assert!(!rule1.allows("a.com"));
990        assert_eq!(rule1.allowed_ports(), PortSpec::PortRange(100..=200));
991
992        assert!(rule2.allows("sub.b.com"));
993        assert!(!rule2.allows("b.com"));
994        assert!(!rule2.allows("a.com"));
995        assert_eq!(rule2.allowed_ports(), PortSpec::Port(80));
996
997        assert!(rule3.allows("a.com"));
998        assert!(!rule3.allows("sub.a.com"));
999        assert!(!rule3.allows("b.com"));
1000        assert_eq!(rule3.allowed_ports(), PortSpec::PortRange(100..=200));
1001
1002        assert!(rule4.allows("a.com"));
1003        assert!(!rule4.allows("sub.a.com"));
1004        assert!(!rule4.allows("b.com"));
1005        assert_eq!(rule4.allowed_ports(), PortSpec::Port(80));
1006    }
1007
1008    #[test]
1009    fn ipv4_rule_all() {
1010        let rules = parse_ipv4_rule("*:*").unwrap();
1011
1012        assert_eq!(rules.len(), 1);
1013        assert!(rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
1014        assert!(rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Outbound));
1015    }
1016
1017    #[test]
1018    fn ipv4_rule_single_ip_all_ports_inbound() {
1019        let rules = parse_ipv4_rule("127.0.0.1:*/in").unwrap();
1020
1021        assert_eq!(rules.len(), 1);
1022        assert!(rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
1023        assert!(!rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Outbound));
1024        assert!(!rules[0].is_allowed([192, 168, 1, 2], 80, Direction::Inbound));
1025        assert!(!rules[0].is_allowed([192, 168, 1, 2], 80, Direction::Outbound));
1026    }
1027
1028    #[test]
1029    fn ipv4_rule_ip_range_all_ports_outbound() {
1030        let mut rules = parse_ipv4_rule("192.168.1.0/24:*/out").unwrap();
1031
1032        let ip_matches = vec![
1033            "192.168.1.1",
1034            "192.168.1.0",
1035            "192.168.1.255",
1036            "192.168.1.100",
1037            "192.168.1.50",
1038        ];
1039
1040        let ip_non_matches = vec![
1041            "192.168.2.0",
1042            "192.167.1.1",
1043            "10.0.0.1",
1044            "172.16.0.1",
1045            "192.168.0.255",
1046        ];
1047
1048        assert_eq!(rules.len(), 1);
1049        let rule = rules.pop().unwrap();
1050
1051        for ip in &ip_matches {
1052            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1053            assert!(rule.is_allowed(ip_addr, 8080, Direction::Outbound));
1054        }
1055        // direction is wrong
1056        for ip in &ip_matches {
1057            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1058            assert!(!rule.is_allowed(ip_addr, 8080, Direction::Inbound));
1059        }
1060        // ip is wrong
1061        for ip in &ip_non_matches {
1062            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1063            assert!(!rule.is_allowed(ip_addr, 8080, Direction::Inbound));
1064        }
1065    }
1066
1067    #[test]
1068    fn ipv4_rule_all_ip_port_range_outbound() {
1069        let rules = parse_ipv4_rule("*:80-90/out").unwrap();
1070
1071        assert_eq!(rules.len(), 1);
1072        assert!(!rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
1073        assert!(rules[0].is_allowed([127, 0, 0, 1], 80, Direction::Outbound));
1074        assert!(rules[0].is_allowed([127, 0, 0, 1], 85, Direction::Outbound));
1075        assert!(rules[0].is_allowed([127, 0, 0, 1], 90, Direction::Outbound));
1076        assert!(!rules[0].is_allowed([127, 0, 0, 1], 443, Direction::Outbound));
1077        assert!(!rules[0].is_allowed([192, 168, 1, 2], 80, Direction::Inbound));
1078        assert!(rules[0].is_allowed([192, 168, 1, 2], 80, Direction::Outbound));
1079    }
1080
1081    #[test]
1082    fn multiple_ipv4_rules() {
1083        let mut rules = parse_ipv4_rule("{127.0.0.1, 192.168.1.0/24}:{80, 8080}/in").unwrap();
1084
1085        let rule1 = rules.pop().unwrap(); // 192.168.1.0/24:8080/in
1086        let rule2 = rules.pop().unwrap(); // 192.168.1.0/24:80/in
1087        let rule3 = rules.pop().unwrap(); // 127.0.0.1:8080/in
1088        let rule4 = rules.pop().unwrap(); // 127.0.0.1:80/in
1089
1090        assert!(rules.is_empty());
1091
1092        let ip_matches = vec![
1093            "192.168.1.1",
1094            "192.168.1.0",
1095            "192.168.1.255",
1096            "192.168.1.100",
1097            "192.168.1.50",
1098        ];
1099
1100        let ip_non_matches = vec![
1101            "192.168.2.0",
1102            "192.167.1.1",
1103            "10.0.0.1",
1104            "172.16.0.1",
1105            "192.168.0.255",
1106        ];
1107
1108        // rule1
1109        for ip in &ip_matches {
1110            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1111            assert!(rule1.is_allowed(ip_addr, 8080, Direction::Inbound));
1112        }
1113        // direction is wrong
1114        for ip in &ip_matches {
1115            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1116            assert!(!rule1.is_allowed(ip_addr, 8080, Direction::Outbound));
1117        }
1118        // port is wrong
1119        for ip in &ip_matches {
1120            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1121            assert!(!rule1.is_allowed(ip_addr, 80, Direction::Inbound));
1122        }
1123        // ip is wrong
1124        for ip in &ip_non_matches {
1125            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1126            assert!(!rule1.is_allowed(ip_addr, 8080, Direction::Inbound));
1127        }
1128
1129        // rule2
1130        for ip in &ip_matches {
1131            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1132            assert!(rule2.is_allowed(ip_addr, 80, Direction::Inbound));
1133        }
1134        // direction is wrong
1135        for ip in &ip_matches {
1136            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1137            assert!(!rule2.is_allowed(ip_addr, 80, Direction::Outbound));
1138        }
1139        // port is wrong
1140        for ip in &ip_matches {
1141            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1142            assert!(!rule2.is_allowed(ip_addr, 8080, Direction::Inbound));
1143        }
1144        // ip is wrong
1145        for ip in &ip_non_matches {
1146            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1147            assert!(!rule2.is_allowed(ip_addr, 80, Direction::Inbound));
1148        }
1149
1150        // rule3
1151        assert!(rule3.is_allowed([127, 0, 0, 1], 8080, Direction::Inbound));
1152        assert!(!rule3.is_allowed([192, 168, 1, 2], 8080, Direction::Inbound));
1153        assert!(!rule3.is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
1154        assert!(!rule3.is_allowed([127, 0, 0, 1], 8080, Direction::Outbound));
1155
1156        // rule4
1157        assert!(rule4.is_allowed([127, 0, 0, 1], 80, Direction::Inbound));
1158        assert!(!rule4.is_allowed([192, 168, 1, 2], 80, Direction::Inbound));
1159        assert!(!rule4.is_allowed([127, 0, 0, 1], 8080, Direction::Inbound));
1160        assert!(!rule4.is_allowed([127, 0, 0, 1], 80, Direction::Outbound));
1161    }
1162
1163    #[test]
1164    fn ipv6_rule_all() {
1165        let rules = parse_ipv6_rule("*:*").unwrap();
1166
1167        assert_eq!(rules.len(), 1);
1168        assert!(rules[0].is_allowed(
1169            "2001:db8::1".parse::<Ipv6Addr>().unwrap(),
1170            80,
1171            Direction::Inbound
1172        ));
1173        assert!(rules[0].is_allowed(
1174            "2001:db8::1".parse::<Ipv6Addr>().unwrap(),
1175            80,
1176            Direction::Outbound
1177        ));
1178    }
1179
1180    #[test]
1181    fn ipv6_rule_single_ip_and_port() {
1182        let rules = parse_ipv6_rule("[2001:db8::1]:80").unwrap();
1183
1184        assert_eq!(rules.len(), 1);
1185        assert!(rules[0].is_allowed(
1186            "2001:db8::1".parse::<Ipv6Addr>().unwrap(),
1187            80,
1188            Direction::Inbound
1189        ));
1190        assert!(rules[0].is_allowed(
1191            "2001:db8::1".parse::<Ipv6Addr>().unwrap(),
1192            80,
1193            Direction::Outbound
1194        ));
1195    }
1196
1197    #[test]
1198    fn ipv6_rule_single_ip_all_ports_inbound() {
1199        let rules = parse_ipv6_rule("[2001:db8::1]:*/in").unwrap();
1200
1201        assert_eq!(rules.len(), 1);
1202        assert!(rules[0].is_allowed(
1203            "2001:db8::1".parse::<Ipv6Addr>().unwrap(),
1204            80,
1205            Direction::Inbound
1206        ));
1207        assert!(!rules[0].is_allowed(
1208            "2002:db8::1".parse::<Ipv6Addr>().unwrap(),
1209            80,
1210            Direction::Inbound
1211        ));
1212        assert!(!rules[0].is_allowed(
1213            "2001:db8::1".parse::<Ipv6Addr>().unwrap(),
1214            8080,
1215            Direction::Outbound
1216        ));
1217    }
1218
1219    #[test]
1220    fn ipv6_rule_ip_range_all_ports_outbound() {
1221        let mut rules = parse_ipv6_rule("[2001:db8::/32]:*/out").unwrap();
1222
1223        let ip_matches = vec![
1224            "2001:db8::1",
1225            "2001:db8::",
1226            "2001:db8:0:0:0:0:0:1234",
1227            "2001:db8::abcd",
1228            "2001:db8::ffff",
1229        ];
1230
1231        let ip_non_matches = vec![
1232            "2001:db9::",
1233            "2001:db7::1",
1234            "2001:dead::1",
1235            "fe80::1",
1236            "::1",
1237        ];
1238
1239        assert_eq!(rules.len(), 1);
1240        let rule = rules.pop().unwrap();
1241
1242        for ip in &ip_matches {
1243            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1244            assert!(rule.is_allowed(ip_addr, 8080, Direction::Outbound));
1245        }
1246        // direction is wrong
1247        for ip in &ip_matches {
1248            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1249            assert!(!rule.is_allowed(ip_addr, 8080, Direction::Inbound));
1250        }
1251        // ip is wrong
1252        for ip in &ip_non_matches {
1253            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1254            assert!(!rule.is_allowed(ip_addr, 8080, Direction::Inbound));
1255        }
1256    }
1257
1258    #[test]
1259    fn multiple_ipv6_rules() {
1260        let mut rules = parse_ipv6_rule("{3001:db8::, 2001:db8::/32}:{80, 8080}/in").unwrap();
1261
1262        let rule1 = rules.pop().unwrap(); // [2001:db8::/32]:8080/in
1263        let rule2 = rules.pop().unwrap(); // [2001:db8::/32]:80/in
1264        let rule3 = rules.pop().unwrap(); // [3001:db8::]:8080/in
1265        let rule4 = rules.pop().unwrap(); // [3001:db8::]:80/in
1266
1267        assert!(rules.is_empty());
1268
1269        let ip_matches = vec![
1270            "2001:db8::1",
1271            "2001:db8::",
1272            "2001:db8:0:0:0:0:0:1234",
1273            "2001:db8::abcd",
1274            "2001:db8::ffff",
1275        ];
1276
1277        let ip_non_matches = vec![
1278            "2001:db9::",
1279            "2001:db7::1",
1280            "2001:dead::1",
1281            "fe80::1",
1282            "::1",
1283        ];
1284
1285        // rule1
1286        for ip in &ip_matches {
1287            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1288            assert!(rule1.is_allowed(ip_addr, 8080, Direction::Inbound));
1289        }
1290        // direction is wrong
1291        for ip in &ip_matches {
1292            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1293            assert!(!rule1.is_allowed(ip_addr, 8080, Direction::Outbound));
1294        }
1295        // port is wrong
1296        for ip in &ip_matches {
1297            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1298            assert!(!rule1.is_allowed(ip_addr, 80, Direction::Inbound));
1299        }
1300        // ip is wrong
1301        for ip in &ip_non_matches {
1302            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1303            assert!(!rule1.is_allowed(ip_addr, 8080, Direction::Inbound));
1304        }
1305
1306        // rule2
1307        for ip in &ip_matches {
1308            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1309            assert!(rule2.is_allowed(ip_addr, 80, Direction::Inbound));
1310        }
1311        // direction is wrong
1312        for ip in &ip_matches {
1313            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1314            assert!(!rule2.is_allowed(ip_addr, 80, Direction::Outbound));
1315        }
1316        // port is wrong
1317        for ip in &ip_matches {
1318            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1319            assert!(!rule2.is_allowed(ip_addr, 8080, Direction::Inbound));
1320        }
1321        // ip is wrong
1322        for ip in &ip_non_matches {
1323            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1324            assert!(!rule2.is_allowed(ip_addr, 80, Direction::Inbound));
1325        }
1326
1327        // rule3
1328        assert!(rule3.is_allowed(
1329            "3001:db8::".parse::<Ipv6Addr>().unwrap(),
1330            8080,
1331            Direction::Inbound
1332        ));
1333        assert!(!rule3.is_allowed(
1334            "4001:db8::".parse::<Ipv6Addr>().unwrap(),
1335            8080,
1336            Direction::Inbound
1337        ));
1338        assert!(!rule3.is_allowed(
1339            "3001:db8::".parse::<Ipv6Addr>().unwrap(),
1340            80,
1341            Direction::Inbound
1342        ));
1343        assert!(!rule3.is_allowed(
1344            "3001:db8::".parse::<Ipv6Addr>().unwrap(),
1345            8080,
1346            Direction::Outbound
1347        ));
1348
1349        // rule4
1350        assert!(rule4.is_allowed(
1351            "3001:db8::".parse::<Ipv6Addr>().unwrap(),
1352            80,
1353            Direction::Inbound
1354        ));
1355        assert!(!rule4.is_allowed(
1356            "4001:db8::".parse::<Ipv6Addr>().unwrap(),
1357            80,
1358            Direction::Inbound
1359        ));
1360        assert!(!rule4.is_allowed(
1361            "3001:db8::".parse::<Ipv6Addr>().unwrap(),
1362            8080,
1363            Direction::Inbound
1364        ));
1365        assert!(!rule4.is_allowed(
1366            "3001:db8::".parse::<Ipv6Addr>().unwrap(),
1367            80,
1368            Direction::Outbound
1369        ));
1370    }
1371
1372    #[test]
1373    fn ruleset_dns() {
1374        let ruleset = Ruleset::from_str("dns:allow={a.com, *.b.com}:{80, 8080}").unwrap();
1375
1376        assert!(ruleset.allows_domain("a.com"));
1377        assert!(!ruleset.allows_domain("sub.a.com"));
1378        assert!(!ruleset.allows_domain("b.com"));
1379        assert!(ruleset.allows_domain("sub.b.com"));
1380        assert!(ruleset.allows_domain("another.sub.b.com"));
1381    }
1382
1383    #[test]
1384    fn ruleset_ipv4() {
1385        let ruleset =
1386            Ruleset::from_str("ipv4:deny={127.0.0.1, 192.168.1.0/24}:{80, 8080}/in").unwrap();
1387
1388        let ip_matches = vec![
1389            "192.168.1.1",
1390            "192.168.1.0",
1391            "192.168.1.255",
1392            "192.168.1.100",
1393            "192.168.1.50",
1394        ];
1395
1396        for ip in &ip_matches {
1397            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1398            assert!(!ruleset.allows_socket((ip_addr, 8080), Direction::Inbound));
1399        }
1400
1401        assert!(!ruleset.allows_socket(([127, 0, 0, 1], 8080), Direction::Inbound));
1402        assert!(!ruleset.allows_socket(([127, 0, 0, 1], 80), Direction::Inbound));
1403    }
1404
1405    #[test]
1406    fn ruleset_ipv6() {
1407        let ruleset =
1408            Ruleset::from_str("ipv6:allow={3001:db8::, 2001:db8::/32}:{80, 8080}/in").unwrap();
1409
1410        let ip_matches = vec![
1411            "2001:db8::1",
1412            "2001:db8::",
1413            "2001:db8:0:0:0:0:0:1234",
1414            "2001:db8::abcd",
1415            "2001:db8::ffff",
1416        ];
1417
1418        for ip in &ip_matches {
1419            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1420            assert!(ruleset.allows_socket((ip_addr, 8080), Direction::Inbound));
1421        }
1422
1423        assert!(ruleset.allows_socket(
1424            ("3001:db8::".parse::<Ipv6Addr>().unwrap(), 8080),
1425            Direction::Inbound
1426        ));
1427        assert!(ruleset.allows_socket(
1428            ("3001:db8::".parse::<Ipv6Addr>().unwrap(), 8080),
1429            Direction::Inbound
1430        ));
1431    }
1432
1433    #[test]
1434    fn ruleset_full() {
1435        let ruleset = Ruleset::from_str(
1436            "dns:allow={a.com, *.b.com}:{80, 8080},
1437            ipv4:deny={127.0.0.1, 192.168.1.0/24}:{80, 8080}/in,
1438            ipv6:allow={3001:db8::, 2001:db8::/32}:{80, 8080}/in",
1439        )
1440        .unwrap();
1441
1442        // dns rules
1443        assert!(ruleset.allows_domain("a.com"));
1444        assert!(!ruleset.allows_domain("sub.a.com"));
1445        assert!(!ruleset.allows_domain("b.com"));
1446        assert!(ruleset.allows_domain("sub.b.com"));
1447        assert!(ruleset.allows_domain("another.sub.b.com"));
1448
1449        // ipv4 rules
1450        let ip_matches = vec![
1451            "192.168.1.1",
1452            "192.168.1.0",
1453            "192.168.1.255",
1454            "192.168.1.100",
1455            "192.168.1.50",
1456        ];
1457
1458        for ip in &ip_matches {
1459            let ip_addr: Ipv4Addr = ip.parse().unwrap();
1460            assert!(!ruleset.allows_socket((ip_addr, 8080), Direction::Inbound));
1461        }
1462
1463        assert!(!ruleset.allows_socket(([127, 0, 0, 1], 8080), Direction::Inbound));
1464        assert!(!ruleset.allows_socket(([127, 0, 0, 1], 80), Direction::Inbound));
1465
1466        // ipv6 rules
1467        let ip_matches = vec![
1468            "2001:db8::1",
1469            "2001:db8::",
1470            "2001:db8:0:0:0:0:0:1234",
1471            "2001:db8::abcd",
1472            "2001:db8::ffff",
1473        ];
1474
1475        for ip in &ip_matches {
1476            let ip_addr: Ipv6Addr = ip.parse().unwrap();
1477            assert!(ruleset.allows_socket((ip_addr, 8080), Direction::Inbound));
1478        }
1479
1480        assert!(ruleset.allows_socket(
1481            ("3001:db8::".parse::<Ipv6Addr>().unwrap(), 8080),
1482            Direction::Inbound
1483        ));
1484        assert!(ruleset.allows_socket(
1485            ("3001:db8::".parse::<Ipv6Addr>().unwrap(), 8080),
1486            Direction::Inbound
1487        ));
1488    }
1489}