@@ -139,14 +139,16 @@ def tokenize_trajectory(
139139 # Find the index of the last assistant message
140140 last_assistant_index = - 1
141141 for i , message in enumerate (history .messages_and_choices ):
142- if isinstance (message , dict ):
143- # Message dict
144- if message ["role" ] == "assistant" and allow_training_without_logprobs :
145- last_assistant_index = i
146- else :
147- # Choice object
148- if message .logprobs is not None or allow_training_without_logprobs :
149- last_assistant_index = i
142+ if (
143+ isinstance (message , dict )
144+ and message ["role" ] == "assistant"
145+ and allow_training_without_logprobs
146+ ):
147+ last_assistant_index = i
148+ elif not isinstance (message , dict ) and (
149+ message .logprobs or allow_training_without_logprobs
150+ ):
151+ last_assistant_index = i
150152 # If there are no trainable assistant messages, return None
151153 if last_assistant_index == - 1 :
152154 return None
@@ -187,7 +189,7 @@ def tokenize_trajectory(
187189 (
188190 message_or_choice
189191 if isinstance (message_or_choice , dict )
190- and message_or_choice ["role" ] ! = "assistant"
192+ and not message_or_choice ["role" ] = = "assistant"
191193 else {
192194 "role" : "assistant" ,
193195 "content" : sentinal_token ,
@@ -203,7 +205,7 @@ def tokenize_trajectory(
203205 assistant_mask : list [int ] = [0 ] * len (token_ids )
204206 logprobs = [float ("nan" )] * len (token_ids )
205207 for message in messages_and_choices :
206- if isinstance (message , dict ) and message ["role" ] ! = "assistant" :
208+ if isinstance (message , dict ) and not message ["role" ] = = "assistant" :
207209 continue
208210 start = token_ids .index (sentinal_token_id )
209211 end = start + 1
@@ -212,7 +214,6 @@ def tokenize_trajectory(
212214 except IndexError :
213215 end_token_id = None
214216 if isinstance (message , dict ):
215- # Message dict
216217 content = message .get ("content" )
217218 assert isinstance (content , str )
218219 content_token_ids = tokenizer .encode (
@@ -223,7 +224,6 @@ def tokenize_trajectory(
223224 logprobs [start :end ] = [float ("nan" )] * len (content_token_ids )
224225 assistant_mask [start :end ] = [1 ] * len (content_token_ids )
225226 else :
226- # Choice object
227227 choice = message
228228 assert choice .logprobs or allow_training_without_logprobs , (
229229 "Chat completion choices must have logprobs"
0 commit comments