3535import java .io .File ;
3636import java .io .FileInputStream ;
3737import java .io .FileOutputStream ;
38+ import java .io .IOException ;
39+ import java .io .InputStream ;
40+ import java .io .InvalidClassException ;
3841import java .io .ObjectInputStream ;
3942import java .io .ObjectOutputStream ;
43+ import java .io .ObjectStreamClass ;
4044import java .util .LinkedList ;
4145import java .util .Queue ;
4246import java .util .concurrent .ConcurrentHashMap ;
@@ -63,13 +67,8 @@ private StateSaver() {
6367 * @param context used to get the available cache dir
6468 */
6569 public static void init (final Context context ) {
66- final File externalCacheDir = context .getExternalCacheDir ();
67- if (externalCacheDir != null ) {
68- cacheDirPath = externalCacheDir .getAbsolutePath ();
69- }
70- if (TextUtils .isEmpty (cacheDirPath )) {
71- cacheDirPath = context .getCacheDir ().getAbsolutePath ();
72- }
70+ // Use internal cache directory to prevent other apps from accessing/modifying the state
71+ cacheDirPath = context .getCacheDir ().getAbsolutePath ();
7372 }
7473
7574 /**
@@ -129,7 +128,7 @@ private static SavedState tryToRestore(@NonNull final SavedState savedState,
129128 }
130129
131130 try (FileInputStream fileInputStream = new FileInputStream (file );
132- ObjectInputStream inputStream = new ObjectInputStream (fileInputStream )) {
131+ ObjectInputStream inputStream = new ValidatingObjectInputStream (fileInputStream )) {
133132 //noinspection unchecked
134133 savedObjects = (Queue <Object >) inputStream .readObject ();
135134 }
@@ -310,6 +309,34 @@ public static void clearStateFiles() {
310309 }
311310 }
312311
312+ private static final class ValidatingObjectInputStream extends ObjectInputStream {
313+ ValidatingObjectInputStream (final InputStream in ) throws IOException {
314+ super (in );
315+ }
316+
317+ @ Override
318+ protected Class <?> resolveClass (final ObjectStreamClass desc )
319+ throws IOException , ClassNotFoundException {
320+ final String name = desc .getName ();
321+ if (!isSafe (name )) {
322+ throw new InvalidClassException ("Unauthorized deserialization attempt" , name );
323+ }
324+ return super .resolveClass (desc );
325+ }
326+
327+ private boolean isSafe (final String name ) {
328+ return name .startsWith ("java.lang." )
329+ || name .startsWith ("java.util." )
330+ || name .startsWith ("org.schabi.newpipe." )
331+ || name .startsWith ("[Ljava.lang." )
332+ || name .startsWith ("[Ljava.util." )
333+ || name .startsWith ("[Lorg.schabi.newpipe." )
334+ || name .equals ("[Z" ) || name .equals ("[B" ) || name .equals ("[C" )
335+ || name .equals ("[S" ) || name .equals ("[I" ) || name .equals ("[J" )
336+ || name .equals ("[F" ) || name .equals ("[D" );
337+ }
338+ }
339+
313340 /**
314341 * Used for describing how to save/read the objects.
315342 * <p>
0 commit comments