1919
# Type alias for node callback functions
NodeCallback = Callable[..., Dict[str, Union[str, int]]]
# Type alias for style callback functions: called with an element plus the
# start and end positions recorded for it; the return value is not used.
StyleCallback = Callable[[_Element, int, int], None]

# Module-level logger registered under the name "html_to_text".
logger = getLogger("html_to_text")
2425
@@ -88,10 +89,12 @@ def __init__(
8889 node_parsed_callback : Union [NodeCallback , None ] = None ,
8990 startpos : int = 0 ,
9091 file : str = "" ,
92+ style_callback : Union [StyleCallback , None ] = None ,
9193 ) -> None :
9294 self .node_parsed_callback = node_parsed_callback
9395 self .startpos = startpos
9496 self .file = file
97+ self .style_callback = style_callback
9598 self .output = StringIO ()
9699 self .add = ""
97100 self .initial_space = False
@@ -106,6 +109,7 @@ def __init__(
106109 self .table_stack : list [dict [str , Union [str , int ]]] = []
107110 self .last_newline = False
108111 self .last_start = ""
112+ self .element_stack : list [tuple [_Element , int ]] = [] # Track (element, start_pos)
109113 self .link_start = 0
110114
111115 # Set up state machine using enum objects directly
@@ -198,6 +202,20 @@ def __init__(
198202
199203 LXMLParser .__init__ (self , item )
200204
def parse_tag(self, item: _Element) -> None:
    """Parse ``item`` via the parent parser, tracking its start position.

    When a style callback is configured, the element and its start offset
    (current output position plus ``startpos``) are pushed onto
    ``element_stack`` for the duration of the parent's ``parse_tag`` call
    and removed again afterwards.

    Args:
        item: Element handed on to the parent's ``parse_tag``.
    """
    # Decide once whether to track: re-reading self.style_callback after the
    # parent call could disagree with the push decision if the attribute is
    # reassigned while parsing, leaving the stack unbalanced.
    if self.style_callback is None:
        super().parse_tag(item)
        return

    start_pos = self.output.tell() + self.startpos
    self.element_stack.append((item, start_pos))
    try:
        super().parse_tag(item)
    finally:
        # Pop even when parsing raises so no stale entry is left behind.
        self.element_stack.pop()
218+
def is_in_pre(self, event_data=None) -> bool:
    """State-machine condition: tell whether pre context was saved before entering ignoring.

    ``event_data`` is accepted but not consulted.
    """
    saved_pre_context = self._pre_context
    return saved_pre_context
@@ -365,6 +383,16 @@ def handle_endtag(self, tag: str, item: _Element) -> None: # type: ignore[overr
365383 elif tag in self ._table_tags and self .node_parsed_callback :
366384 self .table_stack [- 1 ]["end" ] = self .output .tell () + self .startpos
367385 self .table_stack .pop ()
386+
387+ # Call style callback if element has style attribute
388+ if self .style_callback is not None and item .get ('style' ) is not None :
389+ # Find this element's start position from stack
390+ if self .element_stack :
391+ element , start_pos = self .element_stack [- 1 ]
392+ if element == item :
393+ end_pos = self .output .tell () + self .startpos
394+ self .style_callback (item , start_pos , end_pos )
395+
368396 self .last_start = tag
369397
370398 def handle_data (self , data : str , start_tag : Optional [str ]) -> None : # type: ignore[override]
@@ -461,11 +489,12 @@ def html_to_text(
461489 node_parsed_callback : Union [NodeCallback , None ] = None ,
462490 startpos : int = 0 ,
463491 file : str = "" ,
492+ style_callback : Union [StyleCallback , None ] = None ,
464493) -> str :
465494 if isinstance (item , str ):
466495 item = tree_from_string (item )
467496 lxml .html .xhtml_to_html (item ) # type: ignore[arg-type]
468- parser = HTMLParser (item , node_parsed_callback , startpos , file )
497+ parser = HTMLParser (item , node_parsed_callback , startpos , file , style_callback )
469498 text = parser .output .getvalue ()
470499 if parser .last_page is not None :
471500 parser .last_page ["end" ] = parser .output .tell ()
0 commit comments