SreekarB commited on
Commit
3417890
·
verified ·
1 Parent(s): a7f7808

Upload visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +33 -3
visualization.py CHANGED
@@ -7,13 +7,43 @@ def plot_fc_matrices(original, reconstructed, generated):
7
 
8
  vmin, vmax = -1, 1
9
 
10
- im1 = axes[0].imshow(original, cmap='RdBu_r', vmin=vmin, vmax=vmax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  axes[0].set_title('Original FC')
12
 
13
- im2 = axes[1].imshow(reconstructed, cmap='RdBu_r', vmin=vmin, vmax=vmax)
14
  axes[1].set_title('Reconstructed FC')
15
 
16
- im3 = axes[2].imshow(generated, cmap='RdBu_r', vmin=vmin, vmax=vmax)
17
  axes[2].set_title('Generated FC')
18
 
19
  for ax, im in zip(axes, [im1, im2, im3]):
 
7
 
8
  vmin, vmax = -1, 1
9
 
10
+ # Convert 1D arrays to 2D matrices if needed
11
+ def vector_to_matrix(vector):
12
+ """Convert upper triangular vector to full matrix"""
13
+ if len(vector.shape) == 1:
14
+ # Calculate the matrix size based on vector length
15
+ # For a vector of length n, the matrix size is (-1 + sqrt(1 + 8*n))/2
16
+ n = len(vector)
17
+ matrix_size = int((-1 + np.sqrt(1 + 8*n)) / 2)
18
+
19
+ # Create empty matrix
20
+ matrix = np.zeros((matrix_size, matrix_size))
21
+
22
+ # Fill upper triangle
23
+ idx = 0
24
+ for i in range(matrix_size):
25
+ for j in range(i+1, matrix_size):
26
+ matrix[i, j] = vector[idx]
27
+ idx += 1
28
+
29
+ # Make symmetric
30
+ matrix = matrix + matrix.T
31
+
32
+ return matrix
33
+ return vector
34
+
35
+ # Convert inputs to matrices if needed
36
+ original_mat = vector_to_matrix(original)
37
+ reconstructed_mat = vector_to_matrix(reconstructed)
38
+ generated_mat = vector_to_matrix(generated)
39
+
40
+ im1 = axes[0].imshow(original_mat, cmap='RdBu_r', vmin=vmin, vmax=vmax)
41
  axes[0].set_title('Original FC')
42
 
43
+ im2 = axes[1].imshow(reconstructed_mat, cmap='RdBu_r', vmin=vmin, vmax=vmax)
44
  axes[1].set_title('Reconstructed FC')
45
 
46
+ im3 = axes[2].imshow(generated_mat, cmap='RdBu_r', vmin=vmin, vmax=vmax)
47
  axes[2].set_title('Generated FC')
48
 
49
  for ax, im in zip(axes, [im1, im2, im3]):