Conversation
|
Thanks for your contribution! |
| WHERE s.deleted = 0 | ||
| AND s.sample_type != 'full_graph' | ||
| ) sub | ||
| WHERE sub.rn = 1 |
Deduplicate here, so that the same sample cannot be selected more than once.
|
|
||
| def get_v2_group_members(candidates: list[CandidateGraph], num_dtypes: int): | ||
| # Index candidates by op_seq | ||
| by_op_seq = defaultdict(list) |
| b.input_shapes_bucket_id, | ||
| b.input_dtypes_bucket_id, | ||
| s.graph_hash, | ||
| ROW_NUMBER() OVER ( |
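graph_hash is no longer needed here, is it? And what is ROW_NUMBER doing in this query?

Within each (op_seq, shapes, dtypes) partition, rows are numbered in order of creation time, and only rn = 1 (the earliest row) is kept. This deduplicates within a bucket: a bucket may contain several samples, and only one representative is retained.
That said, the code has changed a lot since then.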
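The per-bucket deduplication described above can be sketched in isolation. This is a minimal illustration, not the script's actual query: the `samples` table and its columns are hypothetical stand-ins for the joined sample/bucket data, and it assumes an SQLite build with window-function support (3.25 or later).

```python
import sqlite3


def earliest_per_bucket(rows):
    """Return one representative uuid per (op_seq, shapes, dtypes) bucket.

    rows: iterable of (uuid, op_seq, shapes, dtypes, create_at) tuples.
    The table layout is a hypothetical simplification for illustration.
    """
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE samples ("
        "uuid TEXT, op_seq TEXT, shapes TEXT, dtypes TEXT, create_at INTEGER)"
    )
    conn.executemany("INSERT INTO samples VALUES (?, ?, ?, ?, ?)", rows)
    sql = """
        SELECT uuid FROM (
            SELECT uuid,
                   -- Number rows inside each bucket, earliest first;
                   -- uuid breaks ties so the choice is deterministic.
                   ROW_NUMBER() OVER (
                       PARTITION BY op_seq, shapes, dtypes
                       ORDER BY create_at ASC, uuid ASC
                   ) AS rn
            FROM samples
        ) sub
        WHERE sub.rn = 1  -- keep only the first row of each bucket
        ORDER BY uuid
    """
    result = [row[0] for row in conn.execute(sql)]
    conn.close()
    return result
```

Adding `uuid` as a secondary sort key is a choice made for this sketch: two samples with the same creation time would otherwise be numbered in an unspecified order.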
| """ | ||
|
|
||
| # Index candidates by op_seq | ||
| by_op_seq = defaultdict(list) |
Rename this to candidates_by_op_seq; the naming could use some polish.
| for c in candidates: | ||
| by_op_seq[c.op_seq_bucket_id].append(c) | ||
|
|
||
| rule3_selected_uids = set() |
Pull request overview
This PR enhances the SQLite grouping insertion script by adding “group by v2” rules to generate additional graph_net_sample_groups beyond the existing v1 bucket-based grouping.
Changes:
- Refactors graph_net_sample_groups_insert.py into clearer query/generation/insert phases with per-rule stats output.
- Adds v2 grouping logic (Rule 4 dtype coverage + Rule 3 sparse sampling) controlled by --num_dtypes.
- Updates DB access to use a read-only query() path via sqlite3 alongside SQLAlchemy inserts.
| SELECT | ||
| sub.sample_uid, | ||
| sub.op_seq_bucket_id, | ||
| sub.input_shapes_bucket_id, | ||
| sub.sample_type, | ||
| group_concat(sub.sample_uid, ',') AS all_uids | ||
| FROM ( | ||
| SELECT | ||
| s.uuid AS sample_uid, | ||
| s.sample_type, | ||
| b.op_seq_bucket_id, | ||
| b.input_shapes_bucket_id | ||
| FROM graph_sample s | ||
| JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid | ||
| ORDER BY s.create_at ASC, s.uuid ASC | ||
| ) sub | ||
| GROUP BY sub.sample_type, sub.op_seq_bucket_id, sub.input_shapes_bucket_id; |
query_bucket_groups selects sub.sample_uid as the bucket "head" without aggregating it or using a deterministic window function. In SQLite, selecting a non-GROUP BY column in an aggregate query can return an arbitrary row, so head_uid (and thus Rule 2 heads) may be nondeterministic. Consider selecting the head via MIN(...)/MAX(...) on a stable key, or using a window function to pick the first row by (create_at, uuid) and aggregating the rest separately.
| WITH buckets AS ( | |
| SELECT | |
| s.uuid AS sample_uid, | |
| s.sample_type, | |
| b.op_seq_bucket_id, | |
| b.input_shapes_bucket_id, | |
| FIRST_VALUE(s.uuid) OVER ( | |
| PARTITION BY s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id | |
| ORDER BY s.create_at ASC, s.uuid ASC | |
| ) AS head_uid | |
| FROM graph_sample s | |
| JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid | |
| ) | |
| SELECT | |
| MIN(head_uid) AS head_uid, | |
| op_seq_bucket_id, | |
| input_shapes_bucket_id, | |
| sample_type, | |
| group_concat(sample_uid, ',') AS all_uids | |
| FROM buckets | |
| GROUP BY sample_type, op_seq_bucket_id, input_shapes_bucket_id; |
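To check that the suggested query picks a stable head, it can be run against a small in-memory database. The sketch below is only an illustration: the table definitions are reduced to the columns the query touches, and the sample rows and the `'subgraph'` type value are made up for the demo. It assumes an SQLite build with window-function support (3.25 or later).

```python
import sqlite3

# The suggested query, unchanged apart from indentation.
HEAD_QUERY = """
WITH buckets AS (
    SELECT
        s.uuid AS sample_uid,
        s.sample_type,
        b.op_seq_bucket_id,
        b.input_shapes_bucket_id,
        FIRST_VALUE(s.uuid) OVER (
            PARTITION BY s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id
            ORDER BY s.create_at ASC, s.uuid ASC
        ) AS head_uid
    FROM graph_sample s
    JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid
)
SELECT
    MIN(head_uid) AS head_uid,
    op_seq_bucket_id,
    input_shapes_bucket_id,
    sample_type,
    group_concat(sample_uid, ',') AS all_uids
FROM buckets
GROUP BY sample_type, op_seq_bucket_id, input_shapes_bucket_id;
"""


def make_demo_db():
    """Build a tiny database; schema and rows are hypothetical."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE graph_sample (
            uuid TEXT PRIMARY KEY, sample_type TEXT, create_at INTEGER);
        CREATE TABLE graph_net_sample_buckets (
            sample_uid TEXT, op_seq_bucket_id TEXT, input_shapes_bucket_id TEXT);
    """)
    conn.executemany(
        "INSERT INTO graph_sample VALUES (?, ?, ?)",
        # u1 and u2 share a creation time, so uuid decides the head.
        [("u2", "subgraph", 1), ("u1", "subgraph", 1), ("u3", "subgraph", 0)],
    )
    conn.executemany(
        "INSERT INTO graph_net_sample_buckets VALUES (?, ?, ?)",
        [("u2", "op1", "s1"), ("u1", "op1", "s1"), ("u3", "op2", "s1")],
    )
    return conn


def bucket_heads(conn):
    """Map (sample_type, op_seq, shapes) to (head_uid, sorted member uids)."""
    heads = {}
    for head_uid, op_seq, shapes, sample_type, all_uids in conn.execute(HEAD_QUERY):
        # group_concat order is unspecified, so sort before comparing.
        heads[(sample_type, op_seq, shapes)] = (head_uid, sorted(all_uids.split(",")))
    return heads
```

Because every row in a partition carries the same `head_uid`, wrapping it in `MIN(...)` does not change the value; it only makes the column a proper aggregate so the result no longer depends on which row SQLite happens to pick.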
| b.op_seq_bucket_id, | ||
| b.input_shapes_bucket_id | ||
| FROM graph_sample s | ||
| JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid |
query_bucket_groups doesn’t filter out deleted samples or full_graph samples (unlike query_v2_candidates). If graph_net_sample_buckets contains rows for deleted/full_graph samples (e.g., from older runs), this script will generate groups for data that is supposed to be excluded. Add WHERE s.deleted = 0 AND s.sample_type != 'full_graph' (or equivalent) to keep v1/v2 selection consistent.
| JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid | |
| WHERE s.deleted = 0 | |
| AND s.sample_type != 'full_graph' |
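The effect of the added filter can be seen on a database that still holds bucket rows for excluded samples. This is a hedged sketch: the reduced schema, the row values, and the helper names are assumptions for illustration, not taken from the script.

```python
import sqlite3


def make_stale_db():
    """Database where bucket rows outlive deleted/full_graph samples (hypothetical data)."""
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE graph_sample (
            uuid TEXT PRIMARY KEY, sample_type TEXT, deleted INTEGER);
        CREATE TABLE graph_net_sample_buckets (
            sample_uid TEXT, op_seq_bucket_id TEXT, input_shapes_bucket_id TEXT);
    """)
    conn.executemany(
        "INSERT INTO graph_sample VALUES (?, ?, ?)",
        [("live", "subgraph", 0), ("gone", "subgraph", 1), ("whole", "full_graph", 0)],
    )
    conn.executemany(
        "INSERT INTO graph_net_sample_buckets VALUES (?, ?, ?)",
        [("live", "op1", "s1"), ("gone", "op1", "s1"), ("whole", "op1", "s1")],
    )
    return conn


def eligible_bucket_uids(conn):
    """Return uuids that survive the deleted/full_graph filter, sorted."""
    sql = """
        SELECT s.uuid
        FROM graph_sample s
        JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid
        WHERE s.deleted = 0
          AND s.sample_type != 'full_graph'
        ORDER BY s.uuid
    """
    return [row[0] for row in conn.execute(sql)]
```

One detail worth keeping in mind: a row whose `sample_type` is NULL is also dropped by `!= 'full_graph'`, since the comparison evaluates to NULL rather than true.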
PR Category
Feature Enhancement
Description
Adds new grouping rules for sample groups.