File size: 5,725 Bytes
3c7c02f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
// Reward dimensions rendered as editable columns, in display order.
const REWARD_DIMENSIONS = ["correctness", "style", "conciseness"];

// Fewest rollouts the remove button allows; std() returns 0 for fewer than two values.
const MIN_ROLLOUTS = 2;

/**
 * Arithmetic mean of a list of numbers.
 * @param {number[]} arr
 * @returns {number} The mean, or 0 for an empty list.
 */
function mean(arr) {
  if (arr.length === 0) return 0;
  return arr.reduce((a, b) => a + b, 0) / arr.length;
}

/**
 * Sample standard deviation (n - 1 denominator).
 * @param {number[]} arr
 * @returns {number} The sample std, or 0 when there are fewer than two values.
 */
function std(arr) {
  if (arr.length <= 1) return 0;
  const m = mean(arr);
  // Use sample std (n-1) to match the paper
  const variance =
    arr.reduce((acc, val) => acc + (val - m) ** 2, 0) / (arr.length - 1);
  return Math.sqrt(variance);
}

/**
 * Z-score a group of values: (v - mean) / std.
 * A group with no spread yields all zeros (no relative advantage).
 * @param {number[]} arr
 * @returns {number[]} Normalized values, same length as `arr`.
 */
function normalize(arr) {
  // Compare the values directly instead of relying on std === 0: rounding in
  // mean() (e.g. for [0.1, 0.1, 0.1]) leaves a tiny nonzero std, and dividing
  // by it would turn a constant group into large spurious advantages.
  if (arr.every((v) => v === arr[0])) return arr.map(() => 0);
  const s = std(arr);
  if (s === 0) return arr.map(() => 0);
  const m = mean(arr);
  return arr.map((v) => (v - m) / s);
}

/**
 * Sum of the three reward dimensions for one rollout.
 * @param {{correctness: number, style: number, conciseness: number}} reward
 * @returns {number}
 */
function totalReward(reward) {
  return reward.correctness + reward.style + reward.conciseness;
}

/**
 * GRPO advantages: normalize each rollout's summed reward across the group.
 * @param {Array<{correctness: number, style: number, conciseness: number}>} rewards
 * @returns {number[]} One advantage per rollout.
 */
function calcGrpoAdvantages(rewards) {
  return normalize(rewards.map(totalReward));
}

/**
 * GDPO advantages: normalize each reward dimension across the group
 * independently, then sum the normalized dimensions per rollout.
 * @param {Array<{correctness: number, style: number, conciseness: number}>} rewards
 * @returns {number[]} One advantage per rollout.
 */
function calcGdpoAdvantages(rewards) {
  const normalizedByDim = REWARD_DIMENSIONS.map((dim) =>
    normalize(rewards.map((r) => r[dim]))
  );
  return rewards.map((_, i) =>
    normalizedByDim.reduce((sum, column) => sum + column[i], 0)
  );
}

/**
 * Format a number with three decimals for display.
 * Values that round to zero print as "0.000", never "-0.000".
 * @param {number} n
 * @returns {string}
 */
function formatNumber(n) {
  const text = n.toFixed(3);
  // toFixed keeps the sign of tiny negatives (e.g. -2e-16 from float noise).
  return Number(text) === 0 ? "0.000" : text;
}

/**
 * anywidget entry point: renders an editable table comparing GRPO and GDPO
 * advantages for the rollouts stored in the model's "rewards" value.
 *
 * Clicking a reward cell toggles it (1 -> 0, anything else -> 1). The buttons
 * append a zero-reward rollout or drop the last one (keeping at least two).
 *
 * @param {{model: object, el: HTMLElement}} ctx - Widget model and host element.
 * @returns {() => void} Cleanup callback that detaches the change listener.
 */
function render({ model, el }) {
  const container = document.createElement("div");
  container.className = "grpo-gdpo-root";
  el.appendChild(container);

  // Push a new rewards array to the model; the change:rewards listener redraws.
  function commit(newRewards) {
    model.set("rewards", newRewards);
    model.save_changes();
  }

  // Create a <td> with the given class and text.
  function makeCell(className, text) {
    const cell = document.createElement("td");
    cell.className = className;
    cell.textContent = text;
    return cell;
  }

  // Build the <thead> with one label per table column.
  function buildHeader() {
    const thead = document.createElement("thead");
    const headerRow = document.createElement("tr");
    const headers = [
      "",
      "Correctness",
      "Style",
      "Conciseness",
      "Total",
      "GRPO Adv",
      "GDPO Adv",
      "Difference",
    ];
    headers.forEach((h) => {
      const th = document.createElement("th");
      th.textContent = h;
      headerRow.appendChild(th);
    });
    thead.appendChild(headerRow);
    return thead;
  }

  // Build one <tr>: label, clickable reward cells, total, both advantages, difference.
  function buildRow(rewards, rowIndex, grpoAdv, gdpoAdv) {
    const reward = rewards[rowIndex];
    const row = document.createElement("tr");
    row.appendChild(makeCell("rollout-label", `Rollout ${rowIndex}`));

    REWARD_DIMENSIONS.forEach((dim) => {
      const cell = makeCell("reward-cell", reward[dim]);
      cell.dataset.value = reward[dim];
      cell.addEventListener("click", () => {
        // Copy rather than mutate so the model receives a new array.
        const newRewards = [...rewards];
        newRewards[rowIndex] = {
          ...reward,
          [dim]: reward[dim] === 1 ? 0 : 1,
        };
        commit(newRewards);
      });
      row.appendChild(cell);
    });

    row.appendChild(makeCell("computed-cell", totalReward(reward)));
    row.appendChild(makeCell("computed-cell", formatNumber(grpoAdv)));
    row.appendChild(makeCell("computed-cell", formatNumber(gdpoAdv)));

    const diffText = formatNumber(gdpoAdv - grpoAdv);
    const diffCell = makeCell("diff-cell", diffText);
    // Highlight exactly when the displayed difference is nonzero, so the
    // highlight never disagrees with the rounded text.
    if (diffText !== "0.000") {
      diffCell.classList.add("has-diff");
    }
    row.appendChild(diffCell);
    return row;
  }

  // Build the add/remove controls; removal is disabled at MIN_ROLLOUTS.
  function buildButtons(rewards) {
    const buttonRow = document.createElement("div");
    buttonRow.className = "button-row";

    const addBtn = document.createElement("button");
    addBtn.textContent = "+ Add Rollout";
    addBtn.className = "action-btn";
    addBtn.addEventListener("click", () => {
      commit([...rewards, { correctness: 0, style: 0, conciseness: 0 }]);
    });
    buttonRow.appendChild(addBtn);

    const removeBtn = document.createElement("button");
    removeBtn.textContent = "- Remove Last";
    removeBtn.className = "action-btn";
    removeBtn.disabled = rewards.length <= MIN_ROLLOUTS;
    removeBtn.addEventListener("click", () => {
      if (rewards.length > MIN_ROLLOUTS) {
        commit(rewards.slice(0, -1));
      }
    });
    buttonRow.appendChild(removeBtn);

    return buttonRow;
  }

  // Rebuild the whole view from the model's current rewards.
  function draw() {
    const rewards = model.get("rewards") || [];
    const grpoAdvantages = calcGrpoAdvantages(rewards);
    const gdpoAdvantages = calcGdpoAdvantages(rewards);

    const table = document.createElement("table");
    table.className = "grpo-gdpo-table";
    table.appendChild(buildHeader());

    const tbody = document.createElement("tbody");
    rewards.forEach((_, rowIndex) => {
      tbody.appendChild(
        buildRow(
          rewards,
          rowIndex,
          grpoAdvantages[rowIndex],
          gdpoAdvantages[rowIndex]
        )
      );
    });
    table.appendChild(tbody);

    // Replace the previous render in one step.
    container.replaceChildren(table, buildButtons(rewards));
  }

  // Listen for changes
  model.on("change:rewards", draw);

  // Initial render
  draw();

  return () => {
    model.off("change:rewards", draw);
  };
}

export default { render };