Browse Source

Easy macro to try multiple parsing functions and new error to differentiate between an incorrect AML stream and a failed parse

Isaac Woods 6 years ago
parent
commit
a8f7c291f4
2 changed files with 55 additions and 46 deletions
  1. 1 0
      src/aml/mod.rs
  2. 54 46
      src/aml/parser.rs

+ 1 - 0
src/aml/mod.rs

@@ -23,6 +23,7 @@ pub struct AmlTable {
 #[derive(Debug)]
 pub enum AmlError {
     EndOfStream,
+    NotAnX,
     UnexpectedByte(u8),
     IncompatibleValueConversion,
     InvalidPath(String),

+ 54 - 46
src/aml/parser.rs

@@ -16,6 +16,22 @@ where
     stream: AmlStream<'s>,
 }
 
+/// This macro takes a parser and one or more parsing functions and tries to parse the next part of
+/// the stream with each one. If a parsing function fails, it rolls back the stream and tries the
+/// next one. If none of the functions can parse the next part of the stream, we error on the
+/// unexpected byte.
+macro_rules! try_parse {
+    ($parser: expr, $($function: path),+) => {
+        if false {
+            unreachable!();
+        } $(else if let Some(value) = $parser.try_parse($function)? {
+                Ok(value)
+        })+ else {
+            Err(AmlError::UnexpectedByte($parser.stream.peek()?))
+        }
+    };
+}
+
 impl<'s, 'a, 'h, H> AmlParser<'s, 'a, 'h, H>
 where
     'h: 'a,
@@ -36,7 +52,11 @@ where
         parser.parse_term_list(end_offset)
     }
 
-    fn consume_byte<F>(&mut self, predicate: F) -> Result<u8, AmlError>
+    /// This consumes the next byte in the stream, checking if it fulfils the given predicate. If
+    /// it does, this returns `Ok(the consumed char)`. If it doesn't and we are `checking` if we are
+    /// parsing an 'x', we return `Err(AmlError::NotAnX)`, while if we expect the char to pass the
+    /// predicate, we return `Err(AmlError::UnexpectedByte(the consumed char))`.
+    fn consume_byte<F>(&mut self, predicate: F, checking: bool) -> Result<u8, AmlError>
     where
         F: Fn(u8) -> bool,
     {
@@ -44,18 +64,22 @@ where
 
         match predicate(byte) {
             true => Ok(byte),
-            false => Err(AmlError::UnexpectedByte(byte)),
+            false => if checking {
+                Err(AmlError::NotAnX)
+            } else {
+                Err(AmlError::UnexpectedByte(byte))
+            },
         }
     }
 
-    fn consume_opcode(&mut self, opcode: u8) -> Result<(), AmlError> {
-        self.consume_byte(matches_byte(opcode))?;
+    fn consume_opcode(&mut self, opcode: u8, checking: bool) -> Result<(), AmlError> {
+        self.consume_byte(matches_byte(opcode), checking)?;
         Ok(())
     }
 
-    fn consume_ext_opcode(&mut self, ext_opcode: u8) -> Result<(), AmlError> {
-        self.consume_byte(matches_byte(opcodes::EXT_OPCODE_PREFIX))?;
-        self.consume_byte(matches_byte(ext_opcode))?;
+    fn consume_ext_opcode(&mut self, ext_opcode: u8, checking: bool) -> Result<(), AmlError> {
+        self.consume_byte(matches_byte(opcodes::EXT_OPCODE_PREFIX), checking)?;
+        self.consume_byte(matches_byte(ext_opcode), checking)?;
         Ok(())
     }
 
@@ -72,13 +96,7 @@ where
         match parsing_function(self) {
             Ok(result) => Ok(Some(result)),
 
-            /*
-             * TODO: What about two separate error types here, one to say "this is not an 'x'"
-             * (Return Ok(None)) and
-             * one to say "this is an 'x' but I didn't expect this byte!" (Return
-             * Err(UnexpectedByte))
-             */
-            Err(AmlError::UnexpectedByte(_)) => {
+            Err(AmlError::NotAnX) => {
                 self.stream = stream;
                 Ok(None)
             }
@@ -94,10 +112,6 @@ where
          * Because TermLists don't have PkgLengths, we pass the offset to stop at from whatever
          * explicit-length object we were parsing before.
          */
-
-        /*
-         * We parse until we reach the offset marked as the end of the structure - `end_offset`
-         */
         while self.stream.offset() <= end_offset {
             self.parse_term_object()?;
         }
@@ -106,7 +120,6 @@ where
     }
 
     fn parse_term_object(&mut self) -> Result<(), AmlError> {
-        trace!("Parsing term object");
         /*
          * TermObj := NameSpaceModifierObj | NamedObj | Type1Opcode | Type2Opcode
          * NameSpaceModifierObj := DefAlias | DefName | DefScope
@@ -114,25 +127,20 @@ where
          *             DefCreateField | DefCreateQWordField | DefCreateWordField | DefDataRegion |
          *             DefExternal | DefOpRegion | DefPowerRes | DefProcessor | DefThermalZone
          */
-
-        if let Some(_) = self.try_parse(AmlParser::parse_def_scope)? {
-            Ok(())
-        } else if let Some(_) = self.try_parse(AmlParser::parse_def_op_region)? {
-            Ok(())
-        } else if let Some(_) = self.try_parse(AmlParser::parse_def_field)? {
-            Ok(())
-        } else if let Some(_) = self.try_parse(AmlParser::parse_type1_opcode)? {
-            Ok(())
-        } else {
-            Err(AmlError::UnexpectedByte(self.stream.peek()?))
-        }
+        try_parse!(
+            self,
+            AmlParser::parse_def_scope,
+            AmlParser::parse_def_op_region,
+            AmlParser::parse_def_field //,
+                                       // AmlParser::parse_type1_opcode    TODO: reenable when we can parse them
+        )
     }
 
     fn parse_def_scope(&mut self) -> Result<(), AmlError> {
         /*
          * DefScope := 0x10 PkgLength NameString TermList
          */
-        self.consume_opcode(opcodes::SCOPE_OP)?;
+        self.consume_opcode(opcodes::SCOPE_OP, true)?;
         trace!("Parsing scope op");
         let scope_end_offset = self.parse_pkg_length()?;
 
@@ -164,7 +172,7 @@ where
          * RegionOffset := TermArg => Integer
          * RegionLen := TermArg => Integer
          */
-        self.consume_ext_opcode(opcodes::EXT_OP_REGION_OP)?;
+        self.consume_ext_opcode(opcodes::EXT_OP_REGION_OP, true)?;
         info!("Parsing op region");
 
         let name = self.parse_name_string()?;
@@ -208,7 +216,7 @@ where
         /*
          * DefField = ExtOpPrefix 0x81 PkgLength NameString FieldFlags FieldList
          */
-        self.consume_ext_opcode(opcodes::EXT_FIELD_OP)?;
+        self.consume_ext_opcode(opcodes::EXT_FIELD_OP, true)?;
         let end_offset = self.parse_pkg_length()?;
         let name = self.parse_name_string()?;
         let field_flags = self.stream.next()?;
@@ -271,42 +279,42 @@ where
          */
         match self.stream.peek()? {
             opcodes::BYTE_CONST => {
-                self.consume_opcode(opcodes::BYTE_CONST)?;
+                self.consume_opcode(opcodes::BYTE_CONST, false)?;
                 Ok(AmlValue::Integer(self.stream.next()? as u64))
             }
 
             opcodes::WORD_CONST => {
-                self.consume_opcode(opcodes::WORD_CONST)?;
+                self.consume_opcode(opcodes::WORD_CONST, false)?;
                 Ok(AmlValue::Integer(self.stream.next_u16()? as u64))
             }
 
             opcodes::DWORD_CONST => {
-                self.consume_opcode(opcodes::DWORD_CONST)?;
+                self.consume_opcode(opcodes::DWORD_CONST, false)?;
                 Ok(AmlValue::Integer(self.stream.next_u16()? as u64))
             }
 
             opcodes::QWORD_CONST => {
-                self.consume_opcode(opcodes::QWORD_CONST)?;
+                self.consume_opcode(opcodes::QWORD_CONST, false)?;
                 Ok(AmlValue::Integer(self.stream.next_u16()? as u64))
             }
 
             opcodes::STRING_PREFIX => {
-                self.consume_opcode(opcodes::STRING_PREFIX)?;
+                self.consume_opcode(opcodes::STRING_PREFIX, false)?;
                 unimplemented!(); // TODO
             }
 
             opcodes::ZERO_OP => {
-                self.consume_opcode(opcodes::ZERO_OP)?;
+                self.consume_opcode(opcodes::ZERO_OP, false)?;
                 Ok(AmlValue::Integer(0))
             }
 
             opcodes::ONE_OP => {
-                self.consume_opcode(opcodes::ONE_OP)?;
+                self.consume_opcode(opcodes::ONE_OP, false)?;
                 Ok(AmlValue::Integer(1))
             }
 
             opcodes::ONES_OP => {
-                self.consume_opcode(opcodes::ONES_OP)?;
+                self.consume_opcode(opcodes::ONES_OP, false)?;
                 Ok(AmlValue::Integer(u64::max_value()))
             }
 
@@ -420,10 +428,10 @@ where
          * NameSeg := <LeadNameChar NameChar NameChar NameChar>
          */
         Ok([
-            self.consume_byte(is_lead_name_char)?,
-            self.consume_byte(is_name_char)?,
-            self.consume_byte(is_name_char)?,
-            self.consume_byte(is_name_char)?,
+            self.consume_byte(is_lead_name_char, false)?,
+            self.consume_byte(is_name_char, false)?,
+            self.consume_byte(is_name_char, false)?,
+            self.consume_byte(is_name_char, false)?,
         ])
     }